Compare commits

..

22 Commits

SHA1 Message Date
610f9b437d fix another bug 2025-11-14 06:44:20 +00:00
7d0eb9b4f6 fix lint issue 2025-11-14 06:44:20 +00:00
af6ae22dbd fix lint issue 2025-11-14 06:44:20 +00:00
e3afc32110 revert the change 2025-11-14 06:44:20 +00:00
15aa7e01a9 fix the rebase issue 2025-11-14 06:44:20 +00:00
fc5133bacb fix lint issue 2025-11-14 06:44:20 +00:00
1a49d0cda4 update according to review 2025-11-14 06:44:20 +00:00
e9a3814dea revert the change that already in other's pr 2025-11-14 06:44:20 +00:00
2ace9e465a revert change of case 2025-11-14 06:44:20 +00:00
d990b72872 update hpu to acc 2025-11-14 06:44:20 +00:00
a8243bd1d4 update 2025-11-14 06:44:20 +00:00
1ccc757cac skip failed case 2025-11-14 06:44:20 +00:00
2abf4ecf2f port distributed tensor case for Intel GPU 2025-11-14 06:44:20 +00:00
ff3e2942b4 update according to review 2025-11-14 06:44:19 +00:00
a81e5177de revert change of case 2025-11-14 06:44:19 +00:00
f02dba7893 fix a bug 2025-11-14 06:44:19 +00:00
09abf0ceff update hpu to acc 2025-11-14 06:44:19 +00:00
b4d23566db remove redundant skipper 2025-11-14 06:44:19 +00:00
20ca3c48de update 2025-11-14 06:44:19 +00:00
d83d25dee4 skip failed case 2025-11-14 06:44:19 +00:00
528d3fc4ce enable for xpu 2025-11-14 06:44:19 +00:00
fd178b2e17 port distributed tensor case for Intel GPU 2025-11-14 06:44:19 +00:00
55 changed files with 1289 additions and 1323 deletions

View File

@ -245,9 +245,6 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept { size_t weak_use_count() const noexcept {
return impl_.weak_use_count(); return impl_.weak_use_count();
} }
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const; std::string toString() const;

View File

@ -223,62 +223,6 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t) CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif #endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl( inline void convertBoolToBfloat16Impl(
const bool* __restrict src, const bool* __restrict src,
c10::BFloat16* __restrict dst, c10::BFloat16* __restrict dst,

View File

@ -10,13 +10,6 @@
... ...
} }
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{ {
Cond_cuda Cond_cuda
Memcheck:Cond Memcheck:Cond

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_); (*other.pyinterpreter_)->incref(other.data_);
} }
if (data_ != nullptr) { if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_); (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
} }
data_ = other.data_; data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_; pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() { ~SafePyObject() {
if (data_ != nullptr) { if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_); (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
} }
} }

View File

@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
} }
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) { void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification. // Allowlist verification.
// Only if the devicetype is in the allowlist, // Only if the devicetype is in the allowlist,

View File

@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear(); data_ptr_.clear();
} }
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const { size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive // OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_); TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable, bool resizable,
std::optional<at::Device> device_opt); std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10 } // namespace c10

View File

@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
if (storage_) { if (storage_) {
storage_ = {}; storage_ = {};
} }
pyobj_slot_.maybe_destroy_pyobj();
} }
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
} }
} }
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl { namespace impl {
namespace { namespace {

View File

@ -2178,12 +2178,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_; return &pyobj_slot_;
} }
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private: private:
// See NOTE [std::optional operator usage in CUDA] // See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until // We probably don't want to expose this publicly until
@ -3085,19 +3079,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class; friend class C10_TensorImpl_Size_Check_Dummy_Class;
}; };
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints] // Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for // Changed the size of TensorImpl? If the size went down, good for

View File

@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
#define PANIC(m) \ #define PANIC(m) \
TORCH_INTERNAL_ASSERT( \ TORCH_INTERNAL_ASSERT( \
@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \ "attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died") " on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override { c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach); PANIC(detach);
} }

View File

@ -18,9 +18,6 @@ namespace c10 {
struct IValue; struct IValue;
class OperatorHandle; class OperatorHandle;
struct TensorImpl; struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10 } // namespace c10
namespace torch::jit { namespace torch::jit {
@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject. // Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0; virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call. // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
virtual void decref(PyObject* pyobj) const = 0; // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL. virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of // Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this // detach, which will also arrange for the PyObject to get copied in this

View File

@ -0,0 +1,56 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,58 +8,117 @@
#include <atomic> #include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl { namespace c10::impl {
struct C10_API PyObjectSlot { struct C10_API PyObjectSlot {
public: public:
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
// Query the PyObject interpreter. This may return null if there is no // Query the PyObject interpreter. This may return null if there is no
// interpreter. // interpreter. This is racy!
PyInterpreter* pyobj_interpreter() const { PyInterpreter* pyobj_interpreter();
return pyobj_interpreter_.load(std::memory_order_acquire);
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
} }
PyInterpreter& load_pyobj_interpreter() const { PyInterpreter& load_pyobj_interpreter() const;
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
PyObject* load_pyobj() const { bool owns_pyobj();
return pyobj_.load(std::memory_order_acquire);
}
void store_pyobj(PyObject* obj) { void set_owns_pyobj(bool b);
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
private: private:
// This is now always the global interpreter if the PyObject is set. // This field contains the interpreter tag for this object. See
// Maybe we can remove this field some day... // Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_; std::atomic<PyInterpreter*> pyobj_interpreter_;
// The PyObject representing this Tensor or nullptr. Ownership is managed // This field contains a reference to a PyObject representing this Tensor.
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// reference is already dead. // PyObject for it and set this field. This field does not have to be
std::atomic<PyObject*> pyobj_; // protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
friend class torch::utils::PyObjectPreservation; //
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
}; };
} // namespace c10::impl } // namespace c10::impl

View File

@ -12,10 +12,6 @@ template <typename, typename...>
class class_; class class_;
} }
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 { namespace c10 {
class intrusive_ptr_target; class intrusive_ptr_target;
namespace raw { namespace raw {
@ -37,8 +33,6 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1; constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32); constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne); constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget> template <class TTarget>
struct intrusive_target_default_null_type final { struct intrusive_target_default_null_type final {
@ -61,11 +55,7 @@ inline uint32_t refcount(uint64_t combined_refcount) {
} }
inline uint32_t weakcount(uint64_t combined_refcount) { inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32); return static_cast<uint32_t>(combined_refcount >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
} }
// The only requirement for refcount increment is that it happens-before // The only requirement for refcount increment is that it happens-before
@ -76,6 +66,12 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc; return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
} }
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment( inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) { std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment( return detail::weakcount(atomic_combined_refcount_increment(
@ -103,11 +99,6 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne)); combined_refcount, kWeakReferenceCountOne));
} }
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail } // namespace detail
/** /**
@ -164,23 +155,6 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance // we can atomically operate on both at the same time for performance
// and defined behaviors. // and defined behaviors.
// //
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_; mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8); static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8); static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -198,8 +172,6 @@ class C10_API intrusive_ptr_target {
template <typename T> template <typename T>
friend struct ExclusivelyOwnedTensorTraits; friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected: protected:
// protected destructor. We never want to destruct intrusive_ptr_target* // protected destructor. We never want to destruct intrusive_ptr_target*
// directly. // directly.
@ -283,16 +255,6 @@ class C10_API intrusive_ptr_target {
*/ */
virtual void release_resources() {} virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const { uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order)); return detail::refcount(combined_refcount_.load(order));
} }
@ -303,19 +265,6 @@ class C10_API intrusive_ptr_target {
} }
}; };
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType> template <class TTarget, class NullType>
class weak_intrusive_ptr; class weak_intrusive_ptr;
@ -365,34 +314,18 @@ class intrusive_ptr final {
void retain_() { void retain_() {
if (target_ != NullType::singleton()) { if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment( uint32_t new_refcount =
target_->combined_refcount_, detail::kReferenceCountOne); detail::atomic_refcount_increment(target_->combined_refcount_);
uint32_t new_refcount = detail::refcount(combined);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1, new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero."); "intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
} }
} }
void reset_() noexcept { void reset_() noexcept {
if (target_ != NullType::singleton()) { if (target_ != NullType::singleton()) {
if (is_uniquely_owned()) { if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
// Both counts are 1, so there are no weak references and // Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other // we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion // threads can observe the effects of this target_ deletion
@ -404,10 +337,9 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement( auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne); target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount); if (detail::refcount(combined_refcount) == 0) {
bool has_pyobject = detail::has_pyobject(combined_refcount); bool should_delete =
if (new_refcount == 0) { (combined_refcount == detail::kWeakReferenceCountOne);
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0, // See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references. // weakcount is one larger than the actual number of weak references.
// So we need to decrement it here. // So we need to decrement it here.
@ -424,18 +356,6 @@ class intrusive_ptr final {
if (should_delete) { if (should_delete) {
delete target_; delete target_;
} }
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
} }
} }
} }
@ -602,16 +522,6 @@ class intrusive_ptr final {
return use_count() == 1; return use_count() == 1;
} }
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/** /**
* Returns an owning (!) pointer to the underlying object and makes the * Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased. * intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -1022,7 +932,6 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) { if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>(); return intrusive_ptr<TTarget, NullType>();
} else { } else {
bool increfed = false;
auto combined_refcount = auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed); target_->combined_refcount_.load(std::memory_order_relaxed);
do { do {
@ -1031,31 +940,12 @@ class weak_intrusive_ptr final {
// Return nullptr. // Return nullptr.
return intrusive_ptr<TTarget, NullType>(); return intrusive_ptr<TTarget, NullType>();
} }
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak( } while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount, combined_refcount,
combined_refcount + detail::kReferenceCountOne, combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire, std::memory_order_acquire,
std::memory_order_relaxed)); std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>( return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{}); target_, raw::DontIncreaseRefcount{});
} }
@ -1170,18 +1060,7 @@ namespace intrusive_ptr {
// NullType::singleton to this function // NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) { inline void incref(intrusive_ptr_target* self) {
if (self) { if (self) {
uint64_t combined = detail::atomic_combined_refcount_increment( detail::atomic_refcount_increment(self->combined_refcount_);
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
} }
} }

View File

@ -1,113 +0,0 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are the building blocks for more complex features such as streams, events, and memory management. Make sure to validate device indices and handle runtime errors properly.
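
As a rough illustration of these entry points (not the OpenReg code itself), the sketch below implements all five functions over a toy runtime with a fixed device count and a thread-local current device. The `toy_backend` namespace, `kDeviceCount`, and `g_current_device` are invented for the example, and errors are reported with a plain exception where OpenReg would use `TORCH_CHECK`.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

namespace toy_backend {

using DeviceIndex = int8_t;

constexpr DeviceIndex kDeviceCount = 2;         // assumption: a runtime with 2 devices
thread_local DeviceIndex g_current_device = 0;  // per-thread "active" device

// Reject indices outside [0, kDeviceCount); OpenReg does this with TORCH_CHECK.
inline void check_device_index(int64_t device) {
  if (device < 0 || device >= kDeviceCount) {
    throw std::out_of_range(
        "The device index is out of range. It must be in [0, " +
        std::to_string(static_cast<int>(kDeviceCount)) + "), but got " +
        std::to_string(device) + ".");
  }
}

DeviceIndex device_count() noexcept {
  return kDeviceCount;
}

DeviceIndex current_device() {
  return g_current_device;
}

void set_device(DeviceIndex device) {
  check_device_index(device);
  g_current_device = device;
}

// Swap to `device` and return the previously active index; this is the
// primitive that RAII device guards are built on.
DeviceIndex exchange_device(DeviceIndex device) {
  check_device_index(device);
  DeviceIndex prev = g_current_device;
  g_current_device = device;
  return prev;
}

// Like exchange_device(), but an index of -1 means "leave the device alone".
DeviceIndex maybe_exchange_device(DeviceIndex device) {
  if (device == -1) {
    return current_device();
  }
  return exchange_device(device);
}

} // namespace toy_backend
```

The `exchange_device()`/`maybe_exchange_device()` pair returns the previous index precisely so that guard implementations can restore it later.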
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
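
The two directives above pull in `SetDevice` and the exported `set_device` from `OpenRegFunctions.cpp`. Since literalincludes are not rendered on this page, the snippets they refer to are reproduced below (indentation restored) from the `OpenRegFunctions.cpp` hunk later in this compare; `orError_t`, `orGetDevice`, `orSetDevice`, `OPENREG_CHECK`, and `check_device_index` come from the OpenReg runtime and helper headers.

```cpp
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
  int cur_device = -1;
  OPENREG_CHECK(orGetDevice(&cur_device));
  if (device == cur_device) {
    return orSuccess;
  }
  return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION

// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
  check_device_index(device);
  OPENREG_CHECK(SetDevice(device));
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
```

Note the split: `SetDevice` is the thin wrapper that returns the runtime's error code, while the exported `set_device` validates the index and surfaces failures through `OPENREG_CHECK`/`orCheckFail`.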
### Binding
Expose the C++ functions to Python through the `torch_openreg._C` extension module:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
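
For reference, the `_setDevice` helper pulled in by the first directive is reproduced below (indentation restored) from the `Module.cpp` hunk later in this compare. It is a plain CPython-style function: `HANDLE_TH_ERRORS`/`END_HANDLE_TH_ERRORS` convert C++ exceptions into Python errors and the `THPUtils_*` helpers unpack the argument. The second directive includes the module's method table, which registers this helper as `torch_openreg._C._set_device`.

```cpp
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
  auto device = THPUtils_unpackDeviceIndex(arg);
  torch::utils::device_lazy_init(at::kPrivateUse1);
  c10::openreg::set_device(device);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
```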
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
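
Again, because the include is not rendered here, the wrapper it refers to is reproduced below (indentation restored) from the `torch_openreg/openreg/__init__.py` hunk later in this compare; it forwards non-negative indices to the C binding and leaves range checking to the C++ side.

```python
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
    if device >= 0:
        torch_openreg._C._set_device(device)
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
```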
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
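
As a usage sketch of the user-facing column (assuming the `torch_openreg` package is installed and that importing it registers the backend, which is not shown on this page):

```python
import torch
import torch_openreg  # assumption: importing the extension registers torch.openreg

print(torch.openreg.device_count())    # total number of OpenReg devices
print(torch.openreg.current_device())  # index of the active device, 0 by default

if torch.openreg.device_count() > 1:
    torch.openreg.set_device(1)        # make device 1 the active device
    assert torch.openreg.current_device() == 1
```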
## Guard
Device guards provide automatic device switching with exception safety. They are similar to lock guards in C++: they switch to the target device on construction and restore the previous device on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
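
As a sketch of what registration buys you (assuming a device with index 1 exists), the generic `c10::DeviceGuard` then routes through the registered `PrivateUse1` implementation:

```cpp
#include <c10/core/Device.h>
#include <c10/core/DeviceGuard.h>

void run_on_second_device() {
  c10::Device target(c10::DeviceType::PrivateUse1, /*index=*/1);
  {
    // Construction switches to `target` via the registered guard impl
    // (exchangeDevice); destruction restores the previous device (setDevice).
    c10::DeviceGuard guard(target);
    // ... enqueue work that should run on device 1 ...
  }
  // Back on whatever device was active before the guard.
}
```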
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,7 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob: :glob:
:maxdepth: 1 :maxdepth: 1
device
hooks hooks
autoload autoload
operators operators

View File

@ -4,12 +4,17 @@
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = ""); void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \ #define OPENREG_CHECK(EXPR, ...) \
do { \ do { \
const orError_t __err = EXPR; \ const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \ if (__err != orSuccess) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \ orCheckFail( \
} \ __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0) } while (0)

View File

@ -1,4 +1,3 @@
#include <c10/util/Exception.h>
#include <include/openreg.h> #include <include/openreg.h>
#include "OpenRegException.h" #include "OpenRegException.h"
@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count); return orGetDeviceCount(dev_count);
} }
orError_t GetDevice(DeviceIndex* device) { orError_t GetDevice(c10::DeviceIndex* device) {
int tmp_device = -1; int tmp_device = -1;
auto err = orGetDevice(&tmp_device); auto err = orGetDevice(&tmp_device);
*device = static_cast<DeviceIndex>(tmp_device); *device = static_cast<c10::DeviceIndex>(tmp_device);
return err; return err;
} }
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) { orError_t SetDevice(c10::DeviceIndex device) {
int cur_device = -1; int cur_device = -1;
OPENREG_CHECK(orGetDevice(&cur_device)); orGetDevice(&cur_device);
if (device == cur_device) { if (device == cur_device) {
return orSuccess; return orSuccess;
} }
return orSetDevice(device); return orSetDevice(device);
} }
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() { int device_count_impl() {
int count = 0; int count = 0;
@ -33,37 +31,34 @@ int device_count_impl() {
return count; return count;
} }
OPENREG_EXPORT DeviceIndex device_count() noexcept { OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once // initialize number of devices only once
static int count = []() { static int count = []() {
try { try {
auto result = device_count_impl(); auto result = device_count_impl();
TORCH_CHECK( TORCH_CHECK(
result <= std::numeric_limits<DeviceIndex>::max(), result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed"); "Too many devices, DeviceIndex overflowed");
return result; return result;
} catch (const Error& ex) { } catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning // We don't want to fail, but still log the warning
// msg() returns the message without the stack trace // msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg()); TORCH_WARN("Device initialization: ", ex.msg());
return 0; return 0;
} }
}(); }();
return static_cast<DeviceIndex>(count); return static_cast<c10::DeviceIndex>(count);
} }
OPENREG_EXPORT DeviceIndex current_device() { OPENREG_EXPORT c10::DeviceIndex current_device() {
DeviceIndex cur_device = -1; c10::DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device)); GetDevice(&cur_device);
return cur_device; return cur_device;
} }
// LITERALINCLUDE START: OPENREG set_device FUNCTION OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
OPENREG_EXPORT void set_device(DeviceIndex device) { SetDevice(device);
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
} }
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) { OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1; int current_device = -1;
@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device; return current_device;
} }
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg } // namespace c10::openreg

View File

@ -9,20 +9,10 @@
namespace c10::openreg { namespace c10::openreg {
OPENREG_EXPORT DeviceIndex device_count() noexcept; OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device(); OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device); OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device); OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg } // namespace c10::openreg

View File

@ -2,8 +2,6 @@
namespace c10::openreg { namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg } // namespace c10::openreg

View File

@ -11,7 +11,6 @@
namespace c10::openreg { namespace c10::openreg {
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1; static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
@ -59,7 +58,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
set_device(d.index()); set_device(d.index());
} }
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
/** /**
* Set the current device to c10::Device, without checking for errors * Set the current device to c10::Device, without checking for errors

View File

@ -27,10 +27,6 @@ class TestDevice(TestCase):
self.assertEqual(torch.accelerator.current_device_index(), 1) self.assertEqual(torch.accelerator.current_device_index(), 1)
self.assertEqual(torch.accelerator.current_device_index(), device) self.assertEqual(torch.accelerator.current_device_index(), device)
def test_invalid_device_index(self):
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
torch.accelerator.set_device_index(2)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -34,21 +34,18 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
} }
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR // LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
PyObject* _setDevice(PyObject* self, PyObject* arg) { PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
auto device = THPUtils_unpackDeviceIndex(arg); auto device = THPUtils_unpackLong(arg);
torch::utils::device_lazy_init(at::kPrivateUse1); torch::utils::device_lazy_init(at::kPrivateUse1);
c10::openreg::set_device(device); c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
Py_RETURN_NONE; Py_RETURN_NONE;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) { PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

View File

@ -41,13 +41,8 @@ def current_device():
return torch_openreg._C._get_device() return torch_openreg._C._get_device()
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None: def set_device(device) -> None:
if device >= 0: return torch_openreg._C._set_device(device)
torch_openreg._C._set_device(device)
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
def init(): def init():

View File

@ -6,7 +6,7 @@ import torch.distributed._functional_collectives as funcol
import torch.nn as nn import torch.nn as nn
from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.distributed.fake_pg import FakeStore
@ -14,6 +14,9 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
c10d_functional = torch.ops.c10d_functional c10d_functional = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d c10d_ops = torch.ops.c10d
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
class TestCommMode(TestCase): class TestCommMode(TestCase):
@ -28,7 +31,7 @@ class TestCommMode(TestCase):
dist.init_process_group( dist.init_process_group(
backend="fake", rank=1, world_size=self.world_size, store=store backend="fake", rank=1, world_size=self.world_size, store=store
) )
self.device_type = "cuda" if torch.cuda.is_available() else "cpu" self.device_type = device_type
self.world_pg = dist.distributed_c10d._get_default_group() self.world_pg = dist.distributed_c10d._get_default_group()
def checksAssert(self, comm_mode, key, expected_value, expected_total_value): def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
@ -111,12 +114,12 @@ class TestCommMode(TestCase):
self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1) self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0) self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)
@requires_nccl() @requires_accelerator_dist_backend(["nccl", "xccl"])
def test_comm_mode_with_c10d(self): def test_comm_mode_with_c10d(self):
if not torch.cuda.is_available(): if not torch.accelerator.is_available():
return return
inp = torch.rand(2, 8, 16).cuda() inp = torch.rand(2, 8, 16).to(device_type)
all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
comm_mode = CommDebugMode() comm_mode = CommDebugMode()

View File

@ -658,11 +658,11 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms @with_comms
def test_dtensor_device_mesh_device_conversion(self): def test_dtensor_device_mesh_device_conversion(self):
# construct a cuda device mesh # construct a gpu device mesh
mesh = self.build_device_mesh() mesh = self.build_device_mesh()
# construct from a cpu local tensor with cuda device mesh # construct from a cpu local tensor with gpu device mesh
# should automatically convert the dist tensor to cuda # should automatically convert the dist tensor to gpu
placements = [Shard(0)] placements = [Shard(0)]
local_tensor = torch.randn(3, 3) local_tensor = torch.randn(3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements) dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
@ -711,7 +711,7 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms @with_comms
def test_dtensor_2d_mesh(self): def test_dtensor_2d_mesh(self):
mesh_tensor = torch.arange(self.world_size).reshape(2, 4) mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
# construct a cuda device mesh # construct a gpu device mesh
mesh = DeviceMesh(self.device_type, mesh_tensor) mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 2d device mesh and test if works # construct a dist tensor on 2d device mesh and test if works
@ -733,7 +733,7 @@ class DTensorMeshTest(DTensorTestBase):
@with_comms @with_comms
def test_device_mesh_nd(self): def test_device_mesh_nd(self):
# construct a cuda device mesh # construct a gpu device mesh
mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor) mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 3d device mesh and test if works # construct a dist tensor on 3d device mesh and test if works
@ -1064,8 +1064,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
# Keep everything deterministic. # Keep everything deterministic.
torch.manual_seed(0) torch.manual_seed(0)
tensor = torch.rand(size) tensor = torch.rand(size)
if self.device_type == "cuda": if self.device_type != "cpu":
return tensor.cuda() return tensor.to(self.device_type)
else: else:
return tensor return tensor

View File

@ -39,6 +39,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel, RowwiseParallel,
) )
from torch.distributed.tensor.placement_types import _StridedShard from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_device_type import skipXPUIf
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
@ -47,8 +48,6 @@ from torch.testing._internal.common_utils import (
run_tests, run_tests,
skipIfHpu, skipIfHpu,
skipIfTorchDynamo, skipIfTorchDynamo,
TEST_CUDA,
TEST_HPU,
) )
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase, DTensorTestBase,
@ -95,6 +94,10 @@ aot_eager_graph = aot_autograd(
partition_fn=min_cut_rematerialization_partition, partition_fn=min_cut_rematerialization_partition,
) )
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh): def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
""" """
@ -141,7 +144,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
@property @property
def device_type(self) -> str: def device_type(self) -> str:
return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu" return device_type
@property @property
def world_size(self) -> int: def world_size(self) -> int:
@ -160,9 +163,9 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
res = fn(x) res = fn(x)
res.to_local().sum().backward() res.to_local().sum().backward()
@unittest.skipIf(not TEST_CUDA, "CUDA not available") @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available")
def test_dtensor_basic_export(self): def test_dtensor_basic_export(self):
mesh = DeviceMesh("cuda", torch.arange(self.world_size)) mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
param = torch.randn(4, 4) param = torch.randn(4, 4)
param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False) param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
@ -188,10 +191,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
) )
self.assertExpectedInline( self.assertExpectedInline(
str(ep.graph_module.code).strip(), str(ep.graph_module.code).strip(),
"""\ f"""\
def forward(self, b_buffer, x): def forward(self, b_buffer, x):
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None
view_as = torch.ops.aten.view_as.default(to, to); to = None view_as = torch.ops.aten.view_as.default(to, to); to = None
dtensor___init__0 = self.dtensor___init__0 dtensor___init__0 = self.dtensor___init__0
dtensor_const_func_spec0 = self.dtensor_const_func_spec0 dtensor_const_func_spec0 = self.dtensor_const_func_spec0
@ -206,10 +209,10 @@ def forward(self, b_buffer, x):
# add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
self.assertExpectedInline( self.assertExpectedInline(
str(ep.run_decompositions({}).graph_module.code).strip(), str(ep.run_decompositions({}).graph_module.code).strip(),
"""\ f"""\
def forward(self, b_parametrizations_buffer_original0, x): def forward(self, b_parametrizations_buffer_original0, x):
_assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None
view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
@ -377,6 +380,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
self.assertEqual(res, ref) self.assertEqual(res, ref)
@skipIfHpu @skipIfHpu
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981")
def test_dtensor_dynamic_loss_parallel_log_softmax(self): def test_dtensor_dynamic_loss_parallel_log_softmax(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -815,13 +819,13 @@ def forward(self, b_parametrizations_buffer_original0, x):
out = layer_norm.permute(0, 2, 1) out = layer_norm.permute(0, 2, 1)
return out return out
x = torch.randn(4, 2, 4, requires_grad=True, device="cuda") x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type)
x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False) x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)
y = torch.randn(4, requires_grad=True, device="cuda") y = torch.randn(4, requires_grad=True, device=self.device_type)
y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)
z = torch.randn(4, requires_grad=True, device="cuda") z = torch.randn(4, requires_grad=True, device=self.device_type)
z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False) z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
@ -919,7 +923,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
# pass in tensor as inputs/outputs, create DTensor and run redistribute # pass in tensor as inputs/outputs, create DTensor and run redistribute
# (allgather collective) inside the fn # (allgather collective) inside the fn
def fn(x_dt): def fn(x_dt):
if x_dt.device_mesh.device_type == "cuda": if x_dt.device_mesh.device_type == f"{self.device_type}":
return x_dt + 1 return x_dt + 1
else: else:
return x_dt + 2 return x_dt + 2
@ -1051,7 +1055,7 @@ def forward(self, primals_1):
model = FakeTransformer().to(self.device_type) model = FakeTransformer().to(self.device_type)
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",))
# apply sequence parallel # apply sequence parallel
parallel_plan = { parallel_plan = {

View File
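A minimal sketch of the accelerator-agnostic pattern these test changes apply (hard-coded "cuda" strings and TEST_CUDA/TEST_HPU gates replaced by self.device_type and torch.accelerator checks). It assumes PyTorch 2.6+ where the torch.accelerator namespace exists; the helper name pick_device_type is illustrative and not part of the diff.

import torch

def pick_device_type() -> str:
    # Prefer whatever accelerator backend is present (CUDA, XPU, HPU, ...),
    # falling back to CPU when none is available.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"

print(pick_device_type())  # e.g. "xpu" on an Intel GPU machine, "cpu" otherwise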

@ -27,8 +27,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
parametrize, parametrize,
run_tests, run_tests,
TEST_CUDA,
TEST_HPU,
) )
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class, create_local_tensor_test_class,
@ -541,7 +539,7 @@ class RedistributeTest(DTensorTestBase):
local_out_dt = out_dt.to_local() local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local() local_expected_dt = expected_dt.to_local()
self.assertEqual(out_dt.to_local(), expected_dt.to_local()) self.assertEqual(out_dt.to_local(), expected_dt.to_local())
if TEST_HPU or TEST_CUDA: if torch.accelerator.is_available():
self.assertEqual( self.assertEqual(
comm_mode.get_comm_counts()[ comm_mode.get_comm_counts()[
torch.ops._dtensor.shard_dim_alltoall torch.ops._dtensor.shard_dim_alltoall

View File

@ -296,8 +296,8 @@ class DistTensorOpsTest(DTensorTestBase):
self.assertEqual(dist_tensor.dtype, torch.float32) self.assertEqual(dist_tensor.dtype, torch.float32)
self.assertEqual(zeros_like_dt.dtype, torch.bfloat16) self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
@with_comms
@skip_if_lt_x_gpu(4) @skip_if_lt_x_gpu(4)
@with_comms
def test_stack(self): def test_stack(self):
mesh_2d = DeviceMesh( mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 2) self.device_type, torch.arange(self.world_size).reshape(2, 2)

View File

@ -952,9 +952,7 @@ User code traceback:
self.assertExpectedInline( self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\ """\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip. Graph break: skip: from user code at:
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn File "test_error_messages.py", line N, in fn
assert x is None assert x is None
""", """,
@ -1080,88 +1078,6 @@ from user code:
""", """,
) )
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback(self, records):
def fn(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
self.assertExpectedInline(
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
"""\
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message(self, records):
def fn(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
if x.sum() > 0:
""",
)
@make_logging_test(dynamo=logging.DEBUG)
def test_skip_frame_empty_function_message(self, records):
def empty_fn(x):
pass
torch.compile(empty_fn, backend="eager")(torch.randn(3))
skip_messages = [
r
for r in records
if "intentionally decided to skip the frame" in r.getMessage()
]
self.assertEqual(len(skip_messages), 1)
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
self.assertExpectedInline(
msg,
"""\
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
Reason: no content in function call empty_fn test_error_messages.py N""",
)
@make_logging_test(graph_breaks=True) @make_logging_test(graph_breaks=True)
def test_nested_compile_user_frames(self, records): def test_nested_compile_user_frames(self, records):
def fn(x): def fn(x):
@ -1708,110 +1624,6 @@ from user code:
) )
class NestedGraphBreakLoggingTests(
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
def f2(x):
return f1(x + 2)
def f3(x):
return f2(x + 3)
torch.compile(f3, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
User code traceback:
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
torch.compile(f3, backend="eager")(torch.randn(3))
File "test_error_messages.py", line N, in f3
return f2(x + 3)
File "test_error_messages.py", line N, in f2
return f1(x + 2)
File "test_error_messages.py", line N, in f1
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
def f2(x):
return f1(x + 4)
def f3(x):
return f2(x + 5)
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
User code traceback:
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
File "test_error_messages.py", line N, in f3
return f2(x + 5)
File "test_error_messages.py", line N, in f2
return f1(x + 4)
File "test_error_messages.py", line N, in f1
if x.sum() > 0:
""",
)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests
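A hedged sketch of how the graph_breaks log asserted in the expected strings above can be reproduced outside the test harness, assuming torch._logging.set_logs accepts the graph_breaks artifact; GenericCtxMgr is re-declared here only to keep the snippet self-contained.

import torch

torch._logging.set_logs(graph_breaks=True)

class GenericCtxMgr:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

def fn(x):
    with GenericCtxMgr():
        # A graph break inside an active generic context manager cannot be resumed,
        # so Dynamo skips the whole frame and logs the "Graph break: skip" message.
        torch._dynamo.graph_break()
    return x + 1

torch.compile(fn, backend="eager")(torch.randn(3))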

View File

@ -14036,44 +14036,6 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
except Exception as e: except Exception as e:
self.fail(f"torch.compile failed with error: {e}") self.fail(f"torch.compile failed with error: {e}")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_tensorify_track_item_symint(self):
def _random_resize(image: torch.Tensor):
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump > 0)
torch._check(new_nump * default_patch_size > 1)
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
# Test the function
input_tensor = torch.rand(1, 3, 224, 224)
(h, w), output = _random_resize_compiled(input_tensor)
# Verify output properties
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 3)
self.assertEqual(output.shape[2], h)
self.assertEqual(output.shape[3], w)
self.assertTrue(h % 14 == 0)
self.assertTrue(w % 14 == 0)
self.assertTrue(224 <= h <= 256)
self.assertTrue(224 <= w <= 256)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests
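The removed test above exercised shapes derived from a data-dependent .item() value under capture_scalar_outputs. A smaller sketch of the same pattern, with the caveat that whether it compiles fullgraph can depend on the PyTorch version; make_square is an illustrative name, not from the diff.

import torch

@torch._dynamo.config.patch(capture_scalar_outputs=True)
def run() -> torch.Tensor:
    def make_square(n_tensor: torch.Tensor) -> torch.Tensor:
        n = n_tensor.item()      # unbacked, data-dependent scalar
        torch._check(n > 0)      # give the compiler a usable lower bound
        return torch.ones(n, n)  # shape derived from the scalar

    compiled = torch.compile(make_square, fullgraph=True)
    return compiled(torch.tensor(17))

print(run().shape)  # expected torch.Size([17, 17])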

View File

@ -10895,34 +10895,6 @@ get_out().sum().backward()
self.assertTrue(gradcheck(func, x, fast_mode=True)) self.assertTrue(gradcheck(func, x, fast_mode=True))
def test_grad_thread_safety(self):
import threading
from concurrent.futures import ThreadPoolExecutor
NUM_ITERS = 10
NUM_THREADS = 4
# Concurrent calls to tensor.untyped_storage()
def access_grad(tensor, barrier):
barrier.wait()
return weakref.ref(tensor.grad)
for i in range(NUM_ITERS):
tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
(tensor**2).sum().backward()
barrier = threading.Barrier(NUM_THREADS)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
futures = [
executor.submit(access_grad, tensor, barrier)
for _ in range(NUM_THREADS)
]
# Check that all the grad tensors returned were the same
for future in futures:
self.assertEqual(future.result()(), tensor.grad)
self.assertIsNotNone(tensor.grad)
def index_perm_variable(shape, max_indices): def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple): if not isinstance(shape, tuple):

View File

@ -259,8 +259,7 @@ class TestTorchDeviceType(TestCase):
def test_storage_use_count(self, device): def test_storage_use_count(self, device):
a = torch.randn(10, device=device) a = torch.randn(10, device=device)
prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
# Two references: 'a' and the wrapper returned by untyped_storage() self.assertEqual(prev_cf, 1)
self.assertEqual(prev_cf, 2)
b = a.view(2, 5) b = a.view(2, 5)
self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1) self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1)
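The revised expectation above counts one StorageImpl reference per TensorImpl. A sketch using the same private helper the test uses; the printed values assume the behaviour on this branch.

import torch

a = torch.randn(10)
# Expected 1 on this branch: only 'a' holds the StorageImpl.
print(torch._C._storage_Use_Count(a.untyped_storage()._cdata))
b = a.view(2, 5)
# Expected 2: 'a' and the view 'b' share the same StorageImpl.
print(torch._C._storage_Use_Count(b.untyped_storage()._cdata))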
@ -9325,7 +9324,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
member_var = object() member_var = object()
err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor" err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
with self.assertRaisesRegex(TypeError, err_msg): with self.assertRaisesRegex(RuntimeError, err_msg):
s0 = t0.as_subclass(BadSubTensor) s0 = t0.as_subclass(BadSubTensor)
# FIXME: Port to a test suite that better fits slicing # FIXME: Port to a test suite that better fits slicing
@ -10325,21 +10324,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_tensor_dead_weak_ref(self): def test_tensor_dead_weak_ref(self):
x = torch.ones(2) x = torch.empty(2)
w_x = weakref.ref(x) w_x = weakref.ref(x)
y = torch.ones(2) y = torch.empty(2)
y.grad = x y.grad = x
del x del x
x = w_x() x = w_x()
# x should keep the tensor live. This didn't happen in earlier PyTorch # Ideally, x would keep the tensor live. But CPython doesn't
# versions. # provide enough hooks to do this. So it will go dead and x
# will transmute into an undefined tensor. Not great, but the
# best we can do.
del y del y
self.assertEqual(2, x.sum()) self.assertRaises(RuntimeError, lambda: x.sigmoid())
del x
self.assertIsNone(w_x())
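A standalone sketch of the weak-reference behaviour the restored test asserts; the outcome depends on which side of this diff you are running, and the RuntimeError is what the right-hand side expects.

import weakref
import torch

x = torch.empty(2)
w_x = weakref.ref(x)
y = torch.empty(2)
y.grad = x      # the C++ tensor stays alive through y.grad
del x
x = w_x()       # the wrapper comes back, but nothing re-establishes Python ownership
del y           # last C++ reference drops; the wrapper becomes an undefined tensor
# On the right-hand side of this diff, any use of x now raises RuntimeError:
# x.sigmoid()  -> RuntimeError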
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_dead_weak_ref(self): def test_storage_dead_weak_ref(self):
@ -10347,9 +10345,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
w_x = weakref.ref(x) w_x = weakref.ref(x)
y = torch.tensor(x) y = torch.tensor(x)
del x del x
self.assertIsNotNone(w_x())
x = w_x()
# Ideally, x would keep the storage live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
# will transmute into storage with null StorageImpl. Not great, but the
# best we can do.
del y del y
self.assertIsNone(w_x())
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
def test_tensor_resurrected_weak_ref(self): def test_tensor_resurrected_weak_ref(self):
x = torch.empty(2) x = torch.empty(2)
@ -10410,31 +10415,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertTrue(called) self.assertTrue(called)
def test_storage_thread_safety(self):
import threading
from concurrent.futures import ThreadPoolExecutor
NUM_ITERS = 10
NUM_THREADS = 4
# Concurrent calls to tensor.untyped_storage()
def access_untyped_storage(tensor, barrier):
barrier.wait()
return weakref.ref(tensor.untyped_storage())
for i in range(NUM_ITERS):
tensor = torch.tensor([1.0, 2.0, 3.0])
barrier = threading.Barrier(NUM_THREADS)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
futures = [
executor.submit(access_untyped_storage, tensor, barrier)
for _ in range(NUM_THREADS)
]
# Check that all the storages returned were the same
for future in futures:
self.assertEqual(future.result()(), tensor.untyped_storage())
# FIXME: move to test_linalg # FIXME: move to test_linalg
@torch.inference_mode() @torch.inference_mode()
def test_bmm_multithreaded(self): def test_bmm_multithreaded(self):

View File

@ -1870,7 +1870,7 @@ class ConvertFrame:
raise raise
soft_fail = isinstance(e, Unsupported) soft_fail = isinstance(e, Unsupported)
code = frame.f_code
# This is a soft failure. In the sense, the code path reaches here # This is a soft failure. In the sense, the code path reaches here
# when we do not support graph breaks on bytecodes like LOAD_ATTR, # when we do not support graph breaks on bytecodes like LOAD_ATTR,
# BUILD_SET etc. In such case, we can fallback to eager without # BUILD_SET etc. In such case, we can fallback to eager without
@ -1885,13 +1885,7 @@ class ConvertFrame:
user_stack_formatted = "".join( user_stack_formatted = "".join(
traceback.format_list(user_stack) traceback.format_list(user_stack)
) )
frame_info = exc.format_frame_info(code) user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
user_stack_trace = (
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
"The graph break occurred in the following user code:\n"
f"{user_stack_formatted}"
)
torch._logging.trace_structured( torch._logging.trace_structured(
"artifact", "artifact",
metadata_fn=lambda: { metadata_fn=lambda: {
@ -1903,7 +1897,6 @@ class ConvertFrame:
graph_break_log.debug( graph_break_log.debug(
user_stack_trace, user_stack_trace,
exc_info=True, exc_info=True,
stack_info=config.verbose,
) )
if not config.suppress_errors and not soft_fail: if not config.suppress_errors and not soft_fail:

View File

@ -794,38 +794,6 @@ def format_error_msg_verbose(
return msg return msg
def format_frame_info(code: types.CodeType) -> str:
return (
f"{getattr(code, 'co_name', '<unknown>')} "
f"({getattr(code, 'co_filename', '<unknown>')} "
f"line {getattr(code, 'co_firstlineno', 0)})"
)
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
if code is not None:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: {reason}"
)
else:
return (
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
f"Reason: {reason}"
)
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
frame_info = format_frame_info(code)
return (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
f"{frame_summary}"
)
def format_error_msg( def format_error_msg(
exc: Exception, exc: Exception,
code: types.CodeType, code: types.CodeType,

View File

@ -94,8 +94,6 @@ from .exc import (
BackendCompilerFailed, BackendCompilerFailed,
collapse_resume_frames, collapse_resume_frames,
format_graph_break_message, format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo, get_stack_above_dynamo,
ResumePrologueTracingError, ResumePrologueTracingError,
StepUnsupported, StepUnsupported,
@ -607,9 +605,9 @@ def generic_jump(
) )
# compile a partial subgraph prefix then jump into user code # compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge(): if self.maybe_has_backedge():
msg = format_loop_skip_frame_message( msg = (
self.f_code, "Skipping frame because there is a graph break in a for/while loop\n"
"".join(traceback.format_list([self.frame_summary()])), f"{self.frame_summary()}"
) )
log.info(msg) log.info(msg)
raise exc.SkipFrame(msg) raise exc.SkipFrame(msg)
@ -885,9 +883,9 @@ def break_graph_if_unsupported(
) )
if self.maybe_has_backedge(): if self.maybe_has_backedge():
msg = format_loop_skip_frame_message( msg = (
self.f_code, "Skipping frame because there is a graph break in a for/while loop\n"
"".join(traceback.format_list([self.frame_summary()])), f"{self.frame_summary()}"
) )
log.info(msg) log.info(msg)
raise exc.SkipFrame(msg) from excp raise exc.SkipFrame(msg) from excp
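A hedged sketch of the for/while-loop case that hits the restored, shorter skip message above; it assumes enabling dynamo INFO logging is enough to surface the log.info output.

import logging
import torch

torch._logging.set_logs(dynamo=logging.INFO)

def fn(x):
    for _ in range(2):
        if x.sum() > 0:  # data-dependent branch inside a loop -> graph break -> skip frame
            x = x + 1
    return x

torch.compile(fn, backend="eager")(torch.randn(3))
# Expected log: "Skipping frame because there is a graph break in a for/while loop"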
@ -4628,9 +4626,8 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break and not self.error_on_graph_break
and not self.is_tracing_resume_prologue and not self.is_tracing_resume_prologue
): ):
raise exc.SkipFrame( raise exc.SkipFrame("because no content in function call")
format_skip_frame_message(self.f_code, "no content in function call")
)
self.instruction_pointer = None self.instruction_pointer = None
_step_logger()( _step_logger()(
logging.INFO, logging.INFO,

View File

@ -2248,15 +2248,12 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try: try:
val.data_ptr() # will throw for functorch tensors val.data_ptr() # will throw for functorch tensors
except RuntimeError as e: except RuntimeError as e:
from .exc import format_skip_frame_message, SkipFrame from .exc import SkipFrame
# This will be GradTrackingTensor/BatchedTensor/etc # This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val)) functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame( raise SkipFrame(
format_skip_frame_message( f"torch.compile cannot be run in context: {functorch_subclass_name}"
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
) from e ) from e

View File

@ -42,7 +42,6 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import ( from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception, get_dynamo_observed_exception,
handle_observed_exception, handle_observed_exception,
InfiniteGeneratorError, InfiniteGeneratorError,
@ -1653,13 +1652,8 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg") skip_frame_msg = kwargs.get("msg")
if skip_frame_msg: if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant() skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame( raise SkipFrame(
format_skip_frame_message( f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
) )
elif self.value is torch._dynamo.step_unsupported: elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported raise StepUnsupported

View File

@ -536,14 +536,9 @@ class StorageWeakRefWrapper:
if self.extra_ref_check is not None and not self.extra_ref_check(): if self.extra_ref_check is not None and not self.extra_ref_check():
return False return False
# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata) stor_count = torch._C._storage_Use_Count(self.ref.cdata)
if self.extra_ref_check is not None: return (stor_count - (self.extra_ref_check is not None)) == 0
# if extra_ref_check is not None we expect two additional references:
# - one from the Python storage object
# - one from the cached Tensor
stor_count -= 2
assert stor_count >= 0
return stor_count == 0
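The right-hand side above now discounts a single extra reference while extra_ref_check is installed (the removed code discounted two). A plain-Python sketch of just that accounting; storage_is_dead is an illustrative name, not part of the module.

def storage_is_dead(stor_count: int, has_extra_ref_check: bool) -> bool:
    # While an extra_ref_check callback is installed, exactly one additional
    # reference to the storage is expected and therefore discounted.
    return (stor_count - (1 if has_extra_ref_check else 0)) == 0

assert storage_is_dead(1, True)
assert not storage_is_dead(2, True)
assert storage_is_dead(0, False)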
def __repr__(self) -> str: def __repr__(self) -> str:
if self.ref is None or self.ref.expired(): if self.ref is None or self.ref.expired():
@ -1444,15 +1439,7 @@ class CUDAGraphNode:
self_loc = self_ref() self_loc = self_ref()
if self_loc is None: if self_loc is None:
return False return False
refcount = self_loc.get_output_refcount(i) return self_loc.get_output_refcount(i) == 2
# pyrefly: ignore
if self_loc.cached_tensor_outputs[i]._use_count() > 1:
# c10::Tensor may also hold one reference count
assert refcount >= 3
return refcount == 3
else:
assert refcount >= 2
return refcount == 2
check = functools.partial(check_refcount, i=i) check = functools.partial(check_refcount, i=i)

View File

@ -891,14 +891,10 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format # exception handling - copied from logging.Formatter.format
s = record.message s = record.message
if record.exc_info: if record.exc_info:
from torch._dynamo import config
should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times # Cache the traceback text to avoid converting it multiple times
# (it's constant anyway) # (it's constant anyway)
if should_format_exc: if not record.exc_text:
if not record.exc_text: record.exc_text = self.formatException(record.exc_info)
record.exc_text = self.formatException(record.exc_info)
if record.exc_text: if record.exc_text:
if s[-1:] != "\n": if s[-1:] != "\n":
s = s + "\n" s = s + "\n"

View File

@ -398,27 +398,36 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
// weak_use_count() adds 1 if use_count is non-zero // weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK( TORCH_CHECK(
a->cdata.weak_use_count() == 1, a->cdata->weak_use_count() == 1,
"Expected no weakrefs to t1's Tensor object but got ", "Expected no weakrefs to t1's Tensor object but got ",
a->cdata.weak_use_count() - 1); a->cdata->weak_use_count() - 1);
TORCH_CHECK( TORCH_CHECK(
b->cdata.weak_use_count() == 1, b->cdata->weak_use_count() == 1,
"Expected no weakrefs to t2's Tensor object but got ", "Expected no weakrefs to t2's Tensor object but got ",
b->cdata.weak_use_count() - 1); b->cdata->weak_use_count() - 1);
// NB: Creating local copies of *both* Tensors here ensures that they each
// hold a strong reference to their PyObject. This avoids having to fix up
// reference counts when we swap the PyObject slots below.
at::Tensor tmp_a = a->cdata;
at::Tensor tmp_b = b->cdata;
// Swap the Tensor Impl // Swap the Tensor Impl
a->cdata = tmp_b; c10::MaybeOwned<at::Tensor> tmp = a->cdata;
b->cdata = tmp_a;
// Fix up the PyObjects associated with each TensorImpl // The TensorImpls contain PyObjectSlots that have a reference to the PyObject
a->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(a_); // associated with the TensorImpl. Swap this field as well.
b->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(b_); std::optional<PyObject*> mb_obj_a =
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
std::optional<PyObject*> mb_obj_b =
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT(
mb_obj_a.has_value() && mb_obj_b.has_value(),
"Both tensors should have PyObjects tagged by the current python interpreter");
TORCH_CHECK(mb_obj_a.value() == a_);
TORCH_CHECK(mb_obj_b.value() == b_);
a->cdata = b->cdata;
b->cdata = tmp;
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_);
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_);
Py_RETURN_NONE; Py_RETURN_NONE;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS

View File

@ -45,9 +45,7 @@ struct ConcretePyInterpreterVTable final
std::string name() const override; std::string name() const override;
void incref(PyObject* pyobj) const override; void incref(PyObject* pyobj) const override;
void decref(PyObject* pyobj) const override; void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override;
size_t refcnt(PyObject* pyobj) const override;
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
// operate upon a PyObjectSlot rather than a TensorImpl // operate upon a PyObjectSlot rather than a TensorImpl
@ -237,13 +235,53 @@ py::object torchDispatchFromTensorImpl(
TorchFunctionName::TorchDispatch)); TorchFunctionName::TorchDispatch));
} }
void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const { // NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
const {
// Leak the pyobj if not initialized. This can happen if we are running // Leak the pyobj if not initialized. This can happen if we are running
// exit handlers that are destructing tensors with residual (owned) // exit handlers that are destructing tensors with residual (owned)
// PyObjects stored in them. // PyObjects stored in them.
if (!Py_IsInitialized()) if (!Py_IsInitialized())
return; return;
pybind11::gil_scoped_acquire gil; pybind11::gil_scoped_acquire gil;
// Two possibilities:
// 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
// Storage. Then we must be careful about PyObject resurrection (see
// THPVariable_clear).
// 2. We are decref-ing some other Python object. We don't do
// PyObject resurrection on non-Tensors, so we just carry on as usual
if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
if (THPVariable_Check(pyobj)) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
// too late to rescue the object, so just stub out the PyObject
// so that it fails on subsequent uses. Don't raise an error here;
// you're probably in a destructor.
TORCH_WARN(
"Deallocating Tensor that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"Tensor and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this tensor via the PyObject will now fail.");
(reinterpret_cast<THPVariable*>(pyobj))->cdata =
c10::MaybeOwned<torch::autograd::Variable>();
} else if (THPStorage_Check(pyobj)) {
TORCH_WARN(
"Deallocating UntypedStorage that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this storage via the PyObject will now fail.");
(reinterpret_cast<THPStorage*>(pyobj))->cdata =
c10::MaybeOwned<c10::Storage>();
}
}
Py_DECREF(pyobj); Py_DECREF(pyobj);
} }
@ -254,25 +292,6 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
Py_INCREF(pyobj); Py_INCREF(pyobj);
} }
bool ConcretePyInterpreterVTable::try_incref(
const c10::impl::PyObjectSlot& pyobj_slot) const {
if (!Py_IsInitialized())
return false;
pybind11::gil_scoped_acquire gil;
PyObject* pyobj = pyobj_slot.load_pyobj();
if (!pyobj) {
return false;
}
return PyUnstable_TryIncRef(pyobj);
}
size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const {
if (!Py_IsInitialized() || pyobj == nullptr)
return 0;
pybind11::gil_scoped_acquire gil;
return Py_REFCNT(pyobj);
}
bool isPythonTensor(const at::Tensor& tensor) { bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
} }
@ -601,7 +620,11 @@ static void set_tensor_attr_with_capsule(
const c10::TensorImpl* tensor, const c10::TensorImpl* tensor,
py::capsule& capsule, py::capsule& capsule,
const char* attr_name) { const char* attr_name) {
PyObject* obj = tensor->pyobj_slot()->load_pyobj(); std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto obj = mb_obj.value();
py::handle(obj).attr(attr_name) = capsule; py::handle(obj).attr(attr_name) = capsule;
} }
@ -625,7 +648,11 @@ static c10::ArrayRef<T> get_set_cached_attr(
const c10::TensorImpl* tensor, const c10::TensorImpl* tensor,
const char* base_attr_name, const char* base_attr_name,
const py::object& obj) { const py::object& obj) {
PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj(); std::optional<PyObject*> mb_obj =
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto tensor_obj = mb_obj.value();
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len"); auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
bool is_buffer_allocated = false; bool is_buffer_allocated = false;

View File

@ -23,8 +23,6 @@
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
#include <fmt/format.h> #include <fmt/format.h>
using torch::utils::PyObjectPreservation;
template <> template <>
void THPPointer<c10::StorageImpl>::free() { void THPPointer<c10::StorageImpl>::free() {
if (ptr) { if (ptr) {
@ -34,72 +32,238 @@ void THPPointer<c10::StorageImpl>::free() {
PyTypeObject* THPStorageClass = nullptr; PyTypeObject* THPStorageClass = nullptr;
// Create a new Python Storage object, but don't set the pyobj slot on the PyObject* THPStorage_NewWithStorage(
// c10::Storage object. PyTypeObject* type,
static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) { c10::Storage _storage,
bool allow_preexisting_pyobj) {
TORCH_CHECK(
PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");
auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name);
PyObject* obj = *maybe_pyobj;
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
return THPStorage_Wrap(std::move(_storage));
}
PyObject* obj = type->tp_alloc(type, 0); PyObject* obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
// Ensure that PyUnstable_TryIncref calls don't fail spuriously in auto s = reinterpret_cast<THPStorage*>(obj);
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);
auto s = (THPStorage*)obj; new (&s->cdata) c10::MaybeOwned<c10::Storage>();
new (&s->cdata) c10::Storage(std::move(_storage));
return obj;
}
// Create a new Python Storage object for a new c10::Storage, and set the s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
// pyobj slot. The c10::Storage must not already have a pyobj set.
PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) {
TORCH_CHECK(
type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");
TORCH_INTERNAL_ASSERT(_storage.use_count() == 1);
c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl(); if (!c10::impl::HermeticPyObjectTLS::get_state()) {
PyObject* obj = THPStorage_New(type, std::move(_storage)); s->is_hermetic = false;
PyObjectPreservation::init_fresh_nonatomic( const auto& storage = THPStorage_Unpack(s);
storage_impl, storage_impl->pyobj_slot(), obj); storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj);
return obj; } else {
} s->is_hermetic = true;
// Returns a PyObject wrapper for the c10::Storage object. The existing
// wrapper is returned if it already exists.
PyObject* THPStorage_Wrap(c10::Storage storage) {
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPStorage_New(THPStorageClass, std::move(storage));
} }
return obj;
}
// Wraps the c10::Storage with a storage PyObject
PyObject* THPStorage_Wrap(c10::Storage storage) {
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
}
c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
PyObject* obj = pyobj_slot->load_pyobj(); std::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
if (obj) { /*ignore_hermetic_tls=*/false);
return Py_NewRef(obj); if (maybe_pyobj.has_value()) {
} auto obj = *maybe_pyobj;
if (obj) {
TORCH_CHECK(
THPStorage_Check(obj),
"Expected a storage type, but got ",
Py_TYPE(obj)->tp_name);
obj = THPStorage_New(THPStorageClass, std::move(storage)); if (pyobj_slot->owns_pyobj()) {
PyObject* wrapper = pyobj_slot->set_owns_pyobj(false);
PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj); reinterpret_cast<THPStorage*>(obj)->cdata =
if (wrapper != obj) { c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
// Another thread beat us to it return obj;
Py_DECREF(obj); } else {
return Py_NewRef(wrapper); Py_INCREF(obj);
return obj;
}
}
} }
return obj; return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
} }
static void THPStorage_dealloc(PyObject* self) { static bool THPStorage_isPreservable(THPStorage* self) {
THPStorage* _self = reinterpret_cast<THPStorage*>(self); if (self->cdata.unsafeIsBorrowed()) {
auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot(); return false;
if (pyobj_slot->load_pyobj() == self) {
TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1);
pyobj_slot->clear();
} }
_self->cdata.~Storage(); auto const& storage = THPStorage_Unpack(self);
if (self->is_hermetic) {
return false;
}
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
return false;
}
if (storage.use_count() <= 1) {
return false;
}
return true;
}
static bool THPStorage_tryPreserve(THPStorage* self) {
if (!THPStorage_isPreservable(self)) {
return false;
}
const auto& storage = THPStorage_Unpack(self);
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true);
// NOTE: It is possible to just set the PyObjectSlot here, but the point is
// that we should have already set PyObjectSlot when the storage PyObject
// was created.
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
PyObject* pyobj = *maybe_pyobj;
TORCH_CHECK(
THPStorage_Check(pyobj),
"Expected a storage type, but got ",
Py_TYPE(pyobj)->tp_name);
TORCH_INTERNAL_ASSERT(
(void*)pyobj == (void*)self,
"Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
storage_impl->pyobj_slot()->set_owns_pyobj(true);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference(reinterpret_cast<PyObject*>(self));
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;
}
static void THPStorage_subclass_dealloc(PyObject* self) {
THPStorage* _self = reinterpret_cast<THPStorage*>(self);
if (THPStorage_tryPreserve(_self)) {
return;
}
// Some subclass of StorageBase could be GC-tracked objects even
// though the base class is not
auto* type = Py_TYPE(self);
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
PyObject_GC_UnTrack(self);
}
bool has_finalizer = type->tp_finalize || type->tp_del;
if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
// The finalizer has resurrected the PyObject and there is a new Python
// reference to it, so we can just stop deallocating. Read about
// resurrection from `__del__` here:
// https://docs.python.org/3/reference/datamodel.html#object.__del__
return;
}
PyObject_GC_UnTrack(self);
}
// base test is unnecessary as THPStorage does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
// Resurrected (see above comment about resurrection from `__del__`)
return;
}
PyObject_GC_UnTrack(self);
}
if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
PyObject_GET_WEAKREFS_LISTPTR(self));
while (*list)
_PyWeakref_ClearRef(*list);
}
}
// Clear slots
{
PyTypeObject* base = type;
while (base != &THPStorageType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// Clear __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
_self->cdata.~MaybeOwned<c10::Storage>();
Py_TYPE(_self)->tp_free(self); Py_TYPE(_self)->tp_free(self);
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
} }
static PyObject* THPStorage_pynew( static PyObject* THPStorage_pynew(
@ -389,13 +553,64 @@ static PyMappingMethods THPStorage_mappingmethods = {
reinterpret_cast<binaryfunc>(THPStorage_get), reinterpret_cast<binaryfunc>(THPStorage_get),
reinterpret_cast<objobjargproc>(THPStorage_set)}; reinterpret_cast<objobjargproc>(THPStorage_set)};
struct THPStorageMeta {
PyHeapTypeObject base;
};
static int THPStorageMetaType_init(
PyObject* cls,
PyObject* args,
PyObject* kwargs);
static PyTypeObject THPStorageMetaType = {
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
"torch._C._StorageMeta", /* tp_name */
sizeof(THPStorageMeta), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
// NOLINTNEXTLINE(misc-redundant-expression)
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
THPStorageMetaType_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
};
// TODO: implement equality // TODO: implement equality
PyTypeObject THPStorageType = { PyTypeObject THPStorageType = {
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) PyVarObject_HEAD_INIT(&THPStorageMetaType, 0)
"torch._C.StorageBase", /* tp_name */ "torch._C.StorageBase", /* tp_name */
sizeof(THPStorage), /* tp_basicsize */ sizeof(THPStorage), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
THPStorage_dealloc, /* tp_dealloc */ nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */ 0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */ nullptr, /* tp_getattr */
nullptr, /* tp_setattr */ nullptr, /* tp_setattr */
@ -434,6 +649,15 @@ PyTypeObject THPStorageType = {
THPStorage_pynew, /* tp_new */ THPStorage_pynew, /* tp_new */
}; };
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
(reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
static_cast<destructor>(THPStorage_subclass_dealloc);
return 0;
}
static PyObject* THPStorage_device(THPStorage* self, void* unused) { static PyObject* THPStorage_device(THPStorage* self, void* unused) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
THPStorage_assertNotNull(self); THPStorage_assertNotNull(self);
@ -468,6 +692,13 @@ bool THPStorage_init(PyObject* module) {
THPUtils_addPyMethodDefs(methods, THPStorage_getMethods()); THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods()); THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());
THPStorageMetaType.tp_base = &PyType_Type;
if (PyType_Ready(&THPStorageMetaType) < 0)
return false;
Py_INCREF(&THPStorageMetaType);
PyModule_AddObject(
module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));
THPStorageType.tp_methods = methods.data(); THPStorageType.tp_methods = methods.data();
THPStorageType.tp_getset = THPStorage_properties; THPStorageType.tp_getset = THPStorage_properties;
if (PyType_Ready(&THPStorageType) < 0) if (PyType_Ready(&THPStorageType) < 0)

View File

@ -11,13 +11,15 @@
struct THPStorage { struct THPStorage {
PyObject_HEAD PyObject_HEAD
c10::Storage cdata; c10::MaybeOwned<c10::Storage> cdata;
bool is_hermetic;
}; };
TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
PyTypeObject* type, PyTypeObject* type,
c10::Storage _storage); c10::Storage _storage,
bool allow_preexisting_pyobj = false);
TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; TORCH_PYTHON_API extern PyTypeObject* THPStorageClass;
inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
@ -47,7 +49,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj);
TORCH_PYTHON_API extern PyTypeObject THPStorageType; TORCH_PYTHON_API extern PyTypeObject THPStorageType;
inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) { inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
return storage->cdata; return *storage->cdata;
} }
inline const c10::Storage& THPStorage_Unpack(PyObject* obj) { inline const c10::Storage& THPStorage_Unpack(PyObject* obj) {

View File

@ -529,8 +529,9 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
THPUtils_typename(new_cdata)); THPUtils_typename(new_cdata));
c10::StorageImpl* ptr = c10::StorageImpl* ptr =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata)); static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
self->cdata = self->cdata.~MaybeOwned<c10::Storage>();
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)); self->cdata = c10::MaybeOwned<c10::Storage>::owned(
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
Py_INCREF(self); Py_INCREF(self);
return reinterpret_cast<PyObject*>(self); return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS

View File

@ -180,9 +180,7 @@ struct TORCH_API AccumulateGrad : public Node {
if (!GradMode::is_enabled() && !new_grad.is_sparse() && if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
!new_grad.is_sparse_csr() && !new_grad.is_sparse_csr() &&
!(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) && !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
impl::is_tensor_stealable( at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
new_grad,
num_expected_refs + at::caching::is_cached_tensor(new_grad)) &&
(new_grad.is_mkldnn() || (new_grad.is_mkldnn() ||
utils::obeys_layout_contract(new_grad, variable))) { utils::obeys_layout_contract(new_grad, variable))) {
// See Case 1.1: Stealable dense new_grad // See Case 1.1: Stealable dense new_grad
@ -195,7 +193,7 @@ struct TORCH_API AccumulateGrad : public Node {
// SparseTensor should be the only one holding a reference to these. // SparseTensor should be the only one holding a reference to these.
new_grad._indices().use_count() <= 1 && new_grad._indices().use_count() <= 1 &&
new_grad._values().use_count() <= 1 && new_grad._values().use_count() <= 1 &&
impl::is_tensor_stealable(new_grad, num_expected_refs)) { new_grad.use_count() <= num_expected_refs) {
// Case 1.2: Stealable sparse new_grad // Case 1.2: Stealable sparse new_grad
// No scenario where we expect this to be true currently // No scenario where we expect this to be true currently
TORCH_INTERNAL_ASSERT_DEBUG_ONLY( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(

View File

@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) {
v.is_non_overlapping_and_dense() && v.is_non_overlapping_and_dense() &&
// and we hold the last reference // and we hold the last reference
impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) && at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
v.has_storage() && v.storage().use_count() == 1); v.storage().use_count() == 1);
} }
} // anonymous namespace } // anonymous namespace

View File

@ -54,7 +54,6 @@
using namespace at; using namespace at;
using namespace torch; using namespace torch;
using namespace torch::autograd; using namespace torch::autograd;
using torch::utils::PyObjectPreservation;
namespace { namespace {
class OperatorArgsKwargsView { class OperatorArgsKwargsView {
@ -322,15 +321,20 @@ PyObject* THPVariableClass = nullptr;
PyObject* ParameterClass = nullptr; PyObject* ParameterClass = nullptr;
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
bool allow_preexisting_pyobj = false,
std::optional<bool> has_torch_dispatch_if_known = std::nullopt);
// clang-tidy gets confused by static const // clang-tidy gets confused by static const
static constexpr const char* VOLATILE_WARNING = static constexpr const char* VOLATILE_WARNING =
"volatile was removed and now has no effect. Use " "volatile was removed and now has no effect. Use "
"`with torch.no_grad():` instead."; "`with torch.no_grad():` instead.";
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls);
static bool check_has_torch_dispatch(PyObject* obj) { static bool check_has_torch_dispatch(PyObject* obj) {
if (THPVariable_CheckExact(obj)) { PyTypeObject* tp = Py_TYPE(obj);
if (THPVariable_CheckTypeExact(tp)) {
return false; return false;
} }
py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__"); py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
@ -366,86 +370,152 @@ void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter()); c10::impl::GPUTrace::set_trace(getPyInterpreter());
} }
static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) { PyObject* THPVariable_Wrap(const at::TensorBase& var) {
TORCH_CHECK(
PyObject_TypeCheck(obj, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
Py_TYPE(obj)->tp_name,
" which is not a subclass of the requested type");
}
// Generic for const Tensor& or Tensor&&
template <typename T>
static PyObject* THPVariable_WrapWithType(
T&& var,
std::optional<PyTypeObject*> desired_type) {
if (!var.defined()) { if (!var.defined()) {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) {
c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot(); return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
}
PyObject* obj = pyobj_slot->load_pyobj(); std::optional<PyObject*> mb_obj =
if (obj) { var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
if (desired_type) { /*ignore_hermetic_tls=*/false);
check_tensor_subclass(obj, *desired_type); if (mb_obj.has_value()) {
auto obj = *mb_obj;
if (obj) {
if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
// C++ owns the Python object; this implies there weren't any other
// owning references to the Python object. Since we're making the
// object "live" again on Python side, let's flip back the ownership
// (Python owns C++) as it would now be unsound to deallocate the C++
// object if all C++ references go to zero
var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
reinterpret_cast<THPVariable*>(obj)->cdata =
MaybeOwned<Variable>::owned(Variable(var));
// NB: incref is not necessary, because we are "stealing" the previous
// ownership from the Variable to return it here for the wrap
return obj;
}
Py_INCREF(obj);
return obj;
} }
return Py_NewRef(obj); // TODO: a better invariant is that if we tagged, we MUST have a valid
// PyObject. That's PyObject preservation
// (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
// being a thing, the PyObject field will get cleared when all references
// to the Python object are removed.
} }
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPVariableClass); if (C10_LIKELY(var.device().type() != c10::kXLA)) {
if (desired_type) { return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
type = *desired_type;
} else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) {
if (auto clazz = getPythonTensorClass(var.device())) {
type = reinterpret_cast<PyTypeObject*>(clazz);
}
} }
obj = type->tp_alloc(type, 0); if (auto clazz = getPythonTensorClass(var.device())) {
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); return THPVariable_NewWithVar((PyTypeObject*)clazz, var);
// Ensure that PyUnstable_TryIncref calls don't fail spuriously in
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);
auto v = reinterpret_cast<THPVariable*>(obj);
new (&v->cdata) Tensor(std::forward<T>(var));
if (THPVariable_Unpack(obj).is_uniquely_owned()) {
// We can use a faster non-atomic code path if we have the only reference to
// a fresh Tensor.
PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj);
return obj;
} }
PyObject* wrapper = return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj);
if (wrapper != obj) {
// Another thread beat us to it
Py_DECREF(obj);
if (desired_type) {
check_tensor_subclass(wrapper, *desired_type);
}
return Py_NewRef(wrapper);
}
return obj;
} }
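Both versions of the wrapping routine above serve the same purpose: hand a C++ tensor to Python and get back an owning PyObject*, reusing the cached wrapper when one exists. A minimal usage sketch (illustrative only; my_binding and the at::ones call are hypothetical stand-ins, error handling omitted):

#include <torch/csrc/autograd/python_variable.h>
#include <ATen/ATen.h>

static PyObject* my_binding(PyObject* /*self*/, PyObject* /*args*/) {
  at::Tensor result = at::ones({2, 2});  // stand-in for real work
  // Returns a new reference: either the PyObject already cached on the
  // TensorImpl's pyobj_slot, or a freshly allocated THPVariable.
  return THPVariable_Wrap(result);
}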
PyObject* THPVariable_Wrap(at::TensorBase&& var) { static bool isResurrectable(THPVariable* self) {
return THPVariable_WrapWithType(std::move(var), std::nullopt); // We want to divide this check into 2 cases.
// 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
// true). You might think that in this case, it is impossible for tp_clear to
// be called: surely the C++ reference to the PyObject is keeping it live? And
// you'd be right! In fact, when C++ owns the PyObject, we have an invariant
// that the refcount on the PyObject should be precisely one (because if you
// take out another reference to the PyObject, we're supposed to flip the
// ownership pointer back). In reality, you can violate this invariant
// temporarily with weak references, so we don't test for it in asserts.
// 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
// false). In this case, tp_clear can get called if the PyObject is referenced
// from a dead cycle, and nowhere else. But if resurrection did not occur,
// then the reference to C++ from the PyObject must be the ONLY reference to
// the C++ object.
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& tensor = THPVariable_Unpack(self);
if (!tensor.defined() || tensor.use_count() <= 1) {
return false;
}
// Check if this is hermetic. If it is, no resurrection.
if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false) != (PyObject*)self) {
return false;
}
return true;
} }
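The borrowed/owned split that isResurrectable keys off of comes from c10::MaybeOwned. A small standalone sketch of the two states (not part of this diff; maybe_owned_demo is illustrative):

#include <ATen/core/Tensor.h>
#include <c10/util/MaybeOwned.h>

void maybe_owned_demo(const at::Tensor& t) {
  // Borrowed: no refcount bump; only valid while `t` stays alive.
  auto borrowed = c10::MaybeOwned<at::Tensor>::borrowed(t);
  // Owned: holds its own reference, independent of `t`'s lifetime.
  auto owned = c10::MaybeOwned<at::Tensor>::owned(at::Tensor(t));
  bool b = borrowed.unsafeIsBorrowed();  // true
  bool o = owned.unsafeIsBorrowed();     // false
  // Dereferencing either form yields a const at::Tensor&.
  (void)b, (void)o, (void)borrowed->defined(), (void)owned->defined();
}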
PyObject* THPVariable_Wrap(const at::TensorBase& var) { // returns true if successfully rezzed; if so, cancel the
return THPVariable_WrapWithType(var, std::nullopt); // rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
const auto& tensor = THPVariable_Unpack(self);
if (!isResurrectable(self)) {
return false;
}
// At this point, we are definitely going to resurrect the tensor. So, the
// tensor better be defined :)
TORCH_INTERNAL_ASSERT(tensor.defined());
// There are other C++ owners of the tensor. Flip ownership
// so that C++ owns this Python object, and cancel deallocation.
TORCH_INTERNAL_ASSERT(
!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");
tensor_impl->pyobj_slot()->set_owns_pyobj(true);
// Resurrect the Python object. This is something CPython does
// internally occasionally, see
// https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
// so we just copy the pattern here. Note that we don't have to worry
// about saving and restoring the refcount (as the quoted code does)
// because we actually DO need to reset the refcount to one here, we
// can't assume that some other code has taken care of it.
// NB: this will overreport _Py_RefTotal but based on inspection of object.c
// there is no way to avoid this
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);
// Flip THPVariable to be non-owning
// (near use-after-free miss here: fresh MaybeOwned is created breaking
// reference on Tensor in struct BEFORE we overwrite the old one)
TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
self->cdata = MaybeOwned<Variable>::borrowed(tensor);
// NB: At this point, tensor *could* be dead (e.g., some other C++ thread
// decrefed it.) At this point, it is probably waiting on the GIL to
// deallocate the Python object and will kill self, BUT NOT YET.
return true;
} }
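Reduced to its core, the dealloc-time resurrection trick used above looks like the sketch below (generic, not PyTorch code; still_needed_by_cpp is a hypothetical predicate standing in for the ownership checks):

#include <Python.h>

static bool still_needed_by_cpp(PyObject* self);  // hypothetical check

static void Resurrectable_dealloc(PyObject* self) {
  if (still_needed_by_cpp(self)) {
    // tp_dealloc runs at refcount zero; give the C++ owner a fresh
    // reference and cancel the rest of deallocation.
    _Py_NewReference(self);
    return;
  }
  PyObject_GC_UnTrack(self);
  Py_TYPE(self)->tp_free(self);
}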
PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) { static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
return THPVariable_WrapWithType(var, type); TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_traverse function was not overridden properly");
return 0;
}
static int THPFake_clear(THPVariable* self) {
TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_clear function was not overridden properly");
return 0;
} }
static PyObject* THPVariable_pynew( static PyObject* THPVariable_pynew(
@ -607,16 +677,16 @@ static PyObject* THPVariable_as_subclass(
ParsedArgs<1> parsed_args{}; ParsedArgs<1> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args); auto r = parser.parse(_self, args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0); PyObject* cls = r.pyobject(0);
TORCH_CHECK_TENSOR_SUBTYPE(cls); TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
// guard completely turns off torch dispatch modes, doesn't just pop off the // guard completely turns off torch dispatch modes, doesn't just pop off the
// stack // stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g; torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
c10::impl::DisablePythonDispatcher dpd_g; c10::impl::DisablePythonDispatcher dpd_g;
PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls); return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias());
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
@ -631,7 +701,11 @@ static PyObject* THPVariable_make_subclass(
ParsedArgs<7> parsed_args{}; ParsedArgs<7> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args); auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0); PyObject* cls = r.pyobject(0);
TORCH_CHECK_TENSOR_SUBTYPE(cls); TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
// guard completely turns off torch dispatch modes, doesn't just pop off the // guard completely turns off torch dispatch modes, doesn't just pop off the
// stack // stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g; torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
@ -664,11 +738,7 @@ static PyObject* THPVariable_make_subclass(
data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
} }
PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls); return THPVariable_NewWithVar((PyTypeObject*)cls, data);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
@ -765,7 +835,11 @@ static PyObject* THPVariable_make_wrapper_subclass(
auto r = parser.parse(args, kwargs, parsed_args); auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0); PyObject* cls = r.pyobject(0);
TORCH_CHECK_TENSOR_SUBTYPE(cls); TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
// This is an important safety check; without it, the default behavior will be // This is an important safety check; without it, the default behavior will be
// to continue on to the underlying CPU/CUDA kernel advertised by the dispatch // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
@ -803,8 +877,6 @@ static PyObject* THPVariable_make_wrapper_subclass(
/*storage_size=*/r.toSymIntOptional(14), /*storage_size=*/r.toSymIntOptional(14),
r.toDispatchKeySetOptional(13)); r.toDispatchKeySetOptional(13));
tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
const auto sizes_strides_policy = r.stringViewOptional(10); const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) { if (sizes_strides_policy.has_value()) {
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides( tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
@ -820,7 +892,13 @@ static PyObject* THPVariable_make_wrapper_subclass(
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
} }
return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls); return THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we checked __torch_dispatch__ above; avoid checking again.
/*has_torch_dispatch_if_known=*/true);
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
@ -1621,7 +1699,11 @@ static PyObject* THPVariable_dtensor_new(
auto r = parser.parse(args, kwargs, parsed_args); auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0); PyObject* cls = r.pyobject(0);
TORCH_CHECK_TENSOR_SUBTYPE(cls); TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
#ifndef NDEBUG #ifndef NDEBUG
// This is specifically for making a DTensor, which we know defines // This is specifically for making a DTensor, which we know defines
@ -1674,9 +1756,14 @@ static PyObject* THPVariable_dtensor_new(
/*storage_size=*/std::nullopt, /*storage_size=*/std::nullopt,
extra_dispatch_keys); extra_dispatch_keys);
tensor.set_requires_grad(requires_grad); tensor.set_requires_grad(requires_grad);
tensor.unsafeGetTensorImpl()->set_python_dispatch(true); py::object py_tensor =
py::object py_tensor = py::reinterpret_steal<py::object>( py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls)); (PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we know DTensor has __torch_dispatch__; avoid checking again.
/*has_torch_dispatch_if_known=*/true));
py_tensor.attr(dtensor_interned_strings._spec) = spec; py_tensor.attr(dtensor_interned_strings._spec) = spec;
py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor; py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
return py_tensor.release().ptr(); return py_tensor.release().ptr();
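The reinterpret_steal / attr / release sequence above is a standard pybind11 ownership pattern; a generic sketch, with illustrative names:

#include <pybind11/pybind11.h>
namespace py = pybind11;

PyObject* attach_metadata(PyObject* raw, py::handle spec) {
  // Take ownership of the raw pointer without an extra incref.
  py::object obj = py::reinterpret_steal<py::object>(raw);
  obj.attr("_spec") = spec;    // equivalent to setattr(obj, "_spec", spec)
  // Hand ownership back to the caller as a raw PyObject*.
  return obj.release().ptr();
}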
@ -3353,16 +3440,15 @@ static PyTypeObject THPVariableMetaType = {
nullptr, /* tp_new */ nullptr, /* tp_new */
}; };
static void THPVariable_dealloc(PyObject* self);
static int THPVariable_clear(THPVariable* self);
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg);
static PyTypeObject THPVariableType = { static PyTypeObject THPVariableType = {
PyVarObject_HEAD_INIT(&THPVariableMetaType, 0) PyVarObject_HEAD_INIT(&THPVariableMetaType, 0)
"torch._C.TensorBase", /* tp_name */ "torch._C.TensorBase", /* tp_name */
sizeof(THPVariable), /* tp_basicsize */ sizeof(THPVariable), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
THPVariable_dealloc, /* tp_dealloc */ // This is unspecified, because it is illegal to create a THPVariableType
// directly. Subclasses will have their tp_dealloc set appropriately
// by the metaclass
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */ 0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */ nullptr, /* tp_getattr */
nullptr, /* tp_setattr */ nullptr, /* tp_setattr */
@ -3381,8 +3467,9 @@ static PyTypeObject THPVariableType = {
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC, /* tp_flags */ Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */ nullptr, /* tp_doc */
(traverseproc)THPVariable_traverse, /* tp_traverse */ // Also set by metaclass
(inquiry)THPVariable_clear, /* tp_clear */ (traverseproc)THPFake_traverse, /* tp_traverse */
(inquiry)THPFake_clear, /* tp_clear */
nullptr, /* tp_richcompare */ nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */ 0, /* tp_weaklistoffset */
nullptr, /* tp_iter */ nullptr, /* tp_iter */
@ -3411,68 +3498,345 @@ PyObject* THPVariable_pynew(
type != &THPVariableType, type != &THPVariableType,
"Cannot directly construct TensorBase; subclass it and then construct that"); "Cannot directly construct TensorBase; subclass it and then construct that");
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
// given a raw pointer that will refcount bump // given a raw pointer that will refcount bump
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g., // NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
// alias(), lift_fresh()) which can return Tensor subclasses. We allow // alias(), lift_fresh()) which can return Tensor subclasses. We allow
// these to be passed on directly. // these to be passed on directly.
PyObject* obj = THPVariable_WrapWithType( return THPVariable_NewWithVar(
torch::utils::base_tensor_ctor(args, kwargs), type); type,
if (check_has_torch_dispatch(obj)) { tensor,
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); /*allow_preexisting_pyobj=*/true);
}
return obj;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }
static int THPVariable_clear(THPVariable* self) { static int THPVariable_subclass_clear(THPVariable* self) {
// Is it OK for an object to still be live after running
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
// that an object will dealloc after it's cleared. The source code explicitly
// handles this case:
// https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025
// Note that we don't need to actually resurrect here. There are 2 cases:
// 1. The PyObject is not part of a reference cycle. In this case, we don't
// need to do anything. The GC will move on to try and break the reference
// cycle on another object, which will eventually trigger tp_dealloc (and thus
// resurrection).
// 2. The PyObject is part of a reference cycle. This case should not actually
// be possible, due to the logic in our tp_traverse
// (THPVariable_subclass_traverse).
// In fact, resurrecting here breaks the invariant that "C++ owns Python only
// when PyObject's refcount would otherwise be 0". Most immediately, as we're
// merely breaking reference cycles here, there can be other references to the
// PyObject. *However*, if other objects in the refcycle resurrect, then we
// will be in a state where the PyObject has multiple Python references, yet
// C++ owns the PyObject.
// See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
if (isResurrectable(self)) {
return 0;
}
// First clear Tensor specific things // First clear Tensor specific things
Py_CLEAR(self->backward_hooks); Py_CLEAR(self->backward_hooks);
Py_CLEAR(self->post_accumulate_grad_hooks); Py_CLEAR(self->post_accumulate_grad_hooks);
if (self->cdata.defined()) { const auto& tensor = THPVariable_Unpack(self);
auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot(); if (tensor.defined()) {
// Typically the Tensor's pyobj_slot points back to this object. The only // Two situations to consider:
// time that's not the case is if we had a race in THPVariable_Wrap and we // PyObject -owns-> Tensor
// need to discard the Python object because some other thread beat us to // unsafeIsBorrowed() is FALSE. We're obligated to look through
// setting the pyobj_slot. // Tensor to break references. Clearing cdata must induce the
if (pyobj_slot->load_pyobj() == (PyObject*)self) { // destruction of the C++ Tensor. If there were other references
// A Tensor's Python object should only be destroyed when the Tensor has // to C++ tensor, the Python object would have been resurrected
// no other references too. // by flipping the ownership.
TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1); // Tensor -owns-> PyObject
// unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
// because Tensor asked us to (it's already destructing).
// Clear the pyobj_slot so that a try_incref() call from if (!self->cdata.unsafeIsBorrowed() &&
// weak_intrusive_ptr::lock() won't see a freed pointer. tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
pyobj_slot->clear(); /*ignore_hermetic_tls=*/false) == (PyObject*)self) {
// TODO: empirically, on OS X this assert appears to be untrue
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
// distributed/rpc/test_process_group_agent.py
//
// libc++abi.dylib: terminating with uncaught exception of type
// c10::Error:
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
// please report a bug to PyTorch. Exception raised from
// THPVariable_subclass_clear at
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
// first): frame #0: c10::Error::Error(c10::SourceLocation,
// std::__1::basic_string<char, std::__1::char_traits<char>,
// std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
// #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
// int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
// c10::detail::torchInternalAssertFail(char const*, char const*,
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
// libtorch_python.dylib) frame #4:
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
// libtorch_python.dylib) frame #5: (anonymous
// namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
// _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
// c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
// libc10.dylib) frame #7:
// c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
// + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
// THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
// libtorch_python.dylib) <omitting python frames> frame #47: start + 1
// (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
// TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
if (auto grad_acc =
torch::autograd::impl::try_get_grad_accumulator(tensor)) {
grad_acc->pre_hooks().clear();
grad_acc->tensor_pre_hooks().clear();
grad_acc->retains_grad_hooks().clear();
}
} }
} }
TORCH_INTERNAL_ASSERT(!isResurrectable(self));
{ {
// MapAllocator can take significant time to release large tensors; // MapAllocator can take significant time to release large tensors;
// release the GIL here to avoid impacting main thread perf. // release the GIL here to avoid impacting main thread perf.
pybind11::gil_scoped_release no_gil; pybind11::gil_scoped_release no_gil;
self->cdata = Variable(); self->cdata = MaybeOwned<Variable>();
} }
// Since we override the basic subtype_clear from CPython, we need a crappy
// version here just like for traverse and dealloc
// Clear all slots until we get to the base Tensor class
PyTypeObject* type = Py_TYPE((PyObject*)self);
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base))
clear_slots(base, (PyObject*)self);
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
// Assume we never have managed dict for Tensors as we don't set the flag on
// the base class
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self);
if (dictptr && *dictptr)
Py_CLEAR(*dictptr);
}
return 0; return 0;
} }
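For reference, the bare-bones shape of a tp_clear that the longer versions above elaborate on (a generic sketch, not PyTorch code):

#include <Python.h>

struct SimpleHolder {
  PyObject_HEAD
  PyObject* callback;  // example of an owned PyObject* member
};

static int SimpleHolder_clear(PyObject* self) {
  auto* h = reinterpret_cast<SimpleHolder*>(self);
  // Py_CLEAR nulls the field before decref'ing it, so re-entrant code
  // never observes a dangling pointer.
  Py_CLEAR(h->callback);
  return 0;
}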
static void THPVariable_dealloc(PyObject* self) { // NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
static void THPVariable_subclass_dealloc(PyObject* self) {
if (THPVariable_tryResurrect((THPVariable*)self))
return;
// This is like a crappy version of subtype_dealloc.
// Unfortunately, we cannot directly delegate to
// subtype_dealloc as it will start walking the parent
// chain *starting with* the type of self, which will cause
// us to go back to our custom dealloc.
//
// We have to replicate the subtype_dealloc logic to ensure
// that finalizers are handled correctly
PyTypeObject* type = Py_TYPE(self);
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");
PyObject_GC_UnTrack(self); PyObject_GC_UnTrack(self);
THPVariable_clear((THPVariable*)self); // TODO: consider using trash can
((THPVariable*)self)->cdata.~Variable();
bool has_finalizer = type->tp_finalize || type->tp_del;
if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}
// base test is unnecessary as THPVariable does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}
if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list =
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
while (*list)
_PyWeakref_ClearRef(*list);
}
}
// Clear all slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}
// subtype_dealloc allows for this but we don't
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
// Finally clear out the base THPVariable
THPVariable_subclass_clear((THPVariable*)self);
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
Py_TYPE(self)->tp_free(self); Py_TYPE(self)->tp_free(self);
// Python defined subclasses should always be on the heap
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
} }
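The ordering that THPVariable_subclass_dealloc replicates from CPython's subtype_dealloc boils down to the following skeleton (generic sketch; finalizer and weakref handling elided):

#include <Python.h>

static void SimpleHolder_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  PyObject_GC_UnTrack(self);  // stop the GC from walking a half-dead object
  // ... run finalizers, clear PyObject* members, destroy any C++ payload ...
  tp->tp_free(self);
  Py_DECREF(tp);              // heap types are themselves reference counted
}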
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) { // Creates a new Python object for a Variable.
TORCH_CHECK_TYPE( static PyObject* THPVariable_NewWithVar(
PyType_Check(cls), PyTypeObject* type,
"cls must be a type (got ", const at::TensorBase& _var,
Py_TYPE(cls)->tp_name, bool allow_preexisting_pyobj,
")"); std::optional<bool> has_torch_dispatch_if_known) {
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls); // Make sure that the reinterpret into a THPVariable* will be valid
TORCH_CHECK_TYPE( TORCH_CHECK(
type == &THPVariableType || cls == THPVariableClass || type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType),
PyType_IsSubtype(type, &THPVariableType), "Creating a Tensor subclass from a class ",
"Creating a Tensor subclass from a class that does not inherit from " "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
"Tensor is not possible. Make sure your class inherits from Tensor.");
// This function overwrite the Tensor's pyobj field without extra checks
// Make sure it is not set otherwise we would leak memory
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
// Under some circumstances, we may attempt to create a new Python
// object for a variable that already has a Python object. The most common
// situation this can occur is if you have a TorchDispatchMode active that
// is returning a subclass from lift_fresh (which is invoked to
// appropriately "wrap" a constant tensor into whatever ambient modes are
// active.)
//
// In general, it is impossible to handle this case compositionally.
// Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
// that is transforming all ops (including the internal lift_fresh call that
// transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
// BTensor, where ATensor and BTensor are completely unrelated subclasses
// and there is no way to compose them. There is no way to satisfy the user
// request here: in particular, you can't just try to re-invoke the ATensor
// constructor on the returned BTensor, because (1) this could cause an
// infinite loop--we are already in ATensor.__new__ and (2) there isn't any
// guarantee that ATensor.__new__ supports a single element constructor
// anyway.
//
// However, a more common case is a user just called torch.Tensor([1, 2, 3]),
// and a fake tensor mode is active. Really, all you want is to get back
// a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
// would have returned a fake tensor (concretely, the way this happens
// is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
// turns into a FakeTensor when we call lift_fresh on this real tensor).
// This case is compositional because FakeTensor is a subclass of Tensor, so
// it's valid for us to return it in place of a Tensor. So this is what we
// do.
if (mb_obj.has_value() && mb_obj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name);
// Even if we allow pre-existing PyObject, we don't allow completely
// ignoring the requested type. Check that we fulfilled a subtype
// relation here. In the common case the requested type is Tensor and
// this always succeeds.
PyObject* obj = *mb_obj;
// Check if it's OK to just directly return the Python object without
// allocating a new variable. We just check that the existing Python
// object is a subclass of the requested type.
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
// We may (in fact, we typically will) need to resurrect this
return THPVariable_Wrap(_var);
}
PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto v = (THPVariable*)obj;
// TODO: named constructor to avoid default initialization
new (&v->cdata) MaybeOwned<Variable>();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
// Do NOT initialize pyobj field on the tensor, you own the C++
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
TORCH_INTERNAL_ASSERT(
!check_has_torch_dispatch(obj),
"While HermeticPyObject was enabled, we attempted to create a tensor "
"subclass with __torch_dispatch__. This violates the invariant that "
"operations in HermeticPyObject have equivalent C++ implementations. "
"If your operator registered from Python operator registration isn't "
"doing anything strange, there may be an internal PyTorch bug involving "
"not appropriately disabling TorchDispatchMode before executing "
"Python op registration.");
} else {
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
if (has_torch_dispatch_if_known.has_value()
? *has_torch_dispatch_if_known
: check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
}
}
}
return obj;
} }
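Both the old TORCH_CHECK_TYPE calls and the new TORCH_CHECK_TENSOR_SUBTYPE helper reduce to the same CPython subtype test; a minimal sketch (is_subtype_of is an illustrative name):

#include <Python.h>

static bool is_subtype_of(PyObject* cls, PyTypeObject* base) {
  if (!PyType_Check(cls)) {
    return false;  // not a type object at all
  }
  auto* type = reinterpret_cast<PyTypeObject*>(cls);
  // The equality test is a fast path; PyType_IsSubtype walks the MRO.
  return type == base || PyType_IsSubtype(type, base) != 0;
}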
/// NOTE [ PyObject Traversal ] /// NOTE [ PyObject Traversal ]
@ -3491,7 +3855,7 @@ static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
/// into account these C++ ownership links. /// into account these C++ ownership links.
/// ///
/// The main danger here comes from the fact that, while all python-related code /// The main danger here comes from the fact that, while all python-related code
/// is thread safe wrt the GC execution, other threads might /// is thread safe wrt the GC execution (thanks to the GIL), other threads might
/// be using our C++ objects arbitrarily which can lead to shared_ptr ref count /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count
/// going up or down in between the different traverse/clear invocations. The /// going up or down in between the different traverse/clear invocations. The
/// one constraint we add here that is not explicitly mentioned in the GC /// one constraint we add here that is not explicitly mentioned in the GC
@ -3521,46 +3885,124 @@ static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
/// https://github.com/pytorch/pytorch/issues/7343 /// https://github.com/pytorch/pytorch/issues/7343
/// ///
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) { static int traverse_slots(
PyTypeObject* type,
PyObject* self,
visitproc visit,
void* arg) {
auto n = Py_SIZE(type);
auto mp = type->tp_members;
for (Py_ssize_t i = 0; i < n; i++, mp++) {
if (mp->type == T_OBJECT_EX) {
char* addr = (char*)self + mp->offset;
PyObject* obj = *(PyObject**)addr;
if (obj != nullptr) {
int err = visit(obj, arg);
if (err)
return err;
}
}
}
return 0;
}
static int THPVariable_subclass_traverse(
PyObject* self,
visitproc visit,
void* arg) {
// If the tensor is eligible to be resurrected, don't traverse it; instead
// treat all of its references as a root (as they WOULD be a root since we
// can treat the inbound C++ references as root owners).
//
// This works because unlike conventional GCs, Python's GC operates in two
// phases: first it uses traverse to discover roots, and then it uses traverse
// to do reachability. Bypassing traverse during root discovery forces Python
// to treat self as a root for everything it refers to. For a full
// explanation of the algorithm see
// https://devguide.python.org/garbage_collector/
//
// NB: if we don't hold an owning reference to the underlying Tensor, it is
// possible that the underlying Tensor has already gone dead. In that case,
// it's not safe to access it. But it's also safe to traverse, because if
// the underlying Tensor *is* live, then root discovery will determine that
// self is live, and nothing will get GC'ed anyway (resurrection cannot happen
// if the C++ objects owns the PyObject)
THPVariable* var = reinterpret_cast<THPVariable*>(self); THPVariable* var = reinterpret_cast<THPVariable*>(self);
if (isResurrectable(var)) {
return 0;
}
// Crappy version of subtype_traverse; same deal as
// THPVariable_subclass_dealloc
PyTypeObject* type = Py_TYPE(self);
// Traverse slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
int err = traverse_slots(base, self, visit, arg);
if (err)
return err;
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr && *dictptr)
Py_VISIT(*dictptr);
}
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_VISIT(type);
// Finally traverse THPVariable special stuff
Py_VISIT(var->backward_hooks); Py_VISIT(var->backward_hooks);
Py_VISIT(var->post_accumulate_grad_hooks); Py_VISIT(var->post_accumulate_grad_hooks);
const auto& tensor = THPVariable_Unpack(var); if (!var->cdata.unsafeIsBorrowed()) {
if (tensor.defined()) { const auto& tensor = THPVariable_Unpack(var);
// WARNING: The grad_fn traversal logic is very subtle, if you change if (tensor.defined()) {
// this, be very careful not to re-introduce this bug: // WARNING: The grad_fn traversal logic is very subtle, if you change
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c // this, be very careful not to re-introduce this bug:
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
// We ensure that we follow NOTE [ PyObject Traversal ] here by checking // that this python object is the sole owner of the underlying Tensor and
// that this python object is the sole owner of the underlying Tensor and // that this python object is the sole owner of the underlying Tensor and
// that this Tensor is the sole owner of its grad_fn. In this case, the // that this Tensor is the sole owner of its grad_fn. In this case, the
// only way to get a new reference to the grad_fn is by using this python // only way to get a new reference to the grad_fn is by using this python
// object, which requires the GIL to be accessed. Note that this is only // object, which requires the GIL to be accessed. Note that this is only
// valid as long as users don't share non-owning references across // object, which requires the GIL to be accessed. Note that this is only
// different threads (which is crazy and should never be done). // valid as long as users don't share non-owning references across
auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
if (tensor.use_count() == 1) { if (tensor.use_count() == 1) {
if (autograd_meta) {
// Do NOT call grad_fn() here as that might trigger a recompute
const auto& grad_fn = autograd_meta->grad_fn_;
if (grad_fn && grad_fn.use_count() == 1) {
// All Node can have a pyobj (stored in "pyobj_")
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
}
}
}
}
if (autograd_meta) { if (autograd_meta) {
// Do NOT call grad_fn() here as that might trigger a recompute for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
const auto& grad_fn = autograd_meta->grad_fn_; if (auto pyhook =
if (grad_fn && grad_fn.use_count() == 1) { dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
// All Node can have a pyobj (stored in "pyobj_") Py_VISIT(pyhook->dict);
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
} }
} }
} }
} }
if (autograd_meta) {
for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
}
} }
return 0; return 0;
} }
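Stripped of the tensor-specific ownership logic, a tp_traverse has this generic shape (sketch only; HookHolder is illustrative):

#include <Python.h>

struct HookHolder {
  PyObject_HEAD
  PyObject* pre_hook;
  PyObject* post_hook;
};

static int HookHolder_traverse(PyObject* self, visitproc visit, void* arg) {
  auto* h = reinterpret_cast<HookHolder*>(self);
  Py_VISIT(h->pre_hook);    // Py_VISIT returns from this function on error
  Py_VISIT(h->post_hook);
  Py_VISIT(Py_TYPE(self));  // heap types are visited too, as in the code above
  return 0;
}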
@ -3568,6 +4010,17 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1; return -1;
} }
// It is important for all three of these to be overridden correctly for the
// resurrection checks to properly happen. In particular, an older version
// was not overriding tp_clear here. This lead to the default subtype_clear
// running on the Tensor object (as only TensorBase tp_clear was custom),
// clearing the __dict__ field, before the TensorBase custom clear was called
// and would properly detect the resurrect.
// See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
((PyTypeObject*)cls)->tp_traverse =
(traverseproc)THPVariable_subclass_traverse;
((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
// Don't do anything for the base Tensor class // Don't do anything for the base Tensor class
if (!THPVariableClass) { if (!THPVariableClass) {
@ -17,7 +17,7 @@ namespace py = pybind11;
struct THPVariable { struct THPVariable {
PyObject_HEAD PyObject_HEAD
// Payload // Payload
at::Tensor cdata; c10::MaybeOwned<at::Tensor> cdata;
// Hooks to be run on backwards pass (corresponds to Python attr // Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook') // '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks = nullptr; PyObject* backward_hooks = nullptr;
@ -37,11 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass; TORCH_PYTHON_API extern PyObject* ParameterClass;
bool THPVariable_initModule(PyObject* module); bool THPVariable_initModule(PyObject* module);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var); TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(
const at::TensorBase& var,
PyTypeObject* type);
inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
// Check that a python object is a `Tensor`, but not a `Tensor` subclass. // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
@ -73,7 +69,7 @@ inline bool THPVariable_Check(PyObject* obj) {
} }
inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
return var->cdata; return *var->cdata;
} }
inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
@ -65,9 +65,7 @@ inline at::Tensor clone_obey_contract(
.new_empty_strided_symint( .new_empty_strided_symint(
variable.sym_sizes(), variable.sym_sizes(),
variable.sym_strides(), variable.sym_strides(),
variable.options() variable.options().memory_format(std::nullopt))
.memory_format(std::nullopt)
.dtype(new_grad.dtype()))
.copy_(new_grad)); .copy_(new_grad));
} else { } else {
// (2) // (2)
@ -70,10 +70,6 @@ inline PyObject* wrap(const at::Tensor& tensor) {
return THPVariable_Wrap(tensor); return THPVariable_Wrap(tensor);
} }
inline PyObject* wrap(at::Tensor&& tensor) {
return THPVariable_Wrap(std::move(tensor));
}
inline PyObject* wrap(const at::Scalar& scalar) { inline PyObject* wrap(const at::Scalar& scalar) {
return wrap(scalar_to_tensor(scalar)); return wrap(scalar_to_tensor(scalar));
} }
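The deleted overload above exists so a caller's last reference can be moved rather than copied; the general pattern is the usual lvalue/rvalue pair (sketch; to_python is a hypothetical stand-in for THPVariable_Wrap):

#include <Python.h>
#include <ATen/core/Tensor.h>
#include <utility>

PyObject* to_python(const at::Tensor& t);  // bumps the refcount
PyObject* to_python(at::Tensor&& t);       // steals the caller's reference

PyObject* wrap_example(at::Tensor t) {
  return to_python(std::move(t));  // binds to the && overload, no extra bump
}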
@ -197,22 +197,6 @@ TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
TORCH_API void create_cpp_hook( TORCH_API void create_cpp_hook(
const at::TensorBase& /*self*/, const at::TensorBase& /*self*/,
bool is_retains_grad_hooks = false); bool is_retains_grad_hooks = false);
inline bool is_tensor_stealable(
const at::Tensor& new_grad,
size_t num_expected_refs = 1) {
size_t use_count = new_grad.use_count();
if (use_count <= num_expected_refs) {
return true;
}
if (use_count >= 2 &&
new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) {
// The Python wrapper, if it exists, also has a reference to the Tensor.
num_expected_refs++;
}
return use_count <= num_expected_refs;
}
} // namespace impl } // namespace impl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -910,7 +894,7 @@ inline Variable make_variable(
bool requires_grad = false, bool requires_grad = false,
bool allow_tensor_metadata_change = true) { bool allow_tensor_metadata_change = true) {
if (data.defined()) { if (data.defined()) {
if (impl::is_tensor_stealable(data) && if (data.getIntrusivePtr().use_count() == 1 &&
data.getIntrusivePtr()->unique_version()) { data.getIntrusivePtr()->unique_version()) {
auto data_impl = data.unsafeReleaseIntrusivePtr(); auto data_impl = data.unsafeReleaseIntrusivePtr();
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
@ -1,67 +1,19 @@
#include <torch/csrc/utils/pyobject_preservation.h> #include <torch/csrc/utils/pyobject_preservation.h>
#include <c10/core/impl/PyObjectSlot.h> #include <structmember.h>
#include <c10/util/intrusive_ptr.h>
namespace torch::utils { void clear_slots(PyTypeObject* type, PyObject* self) {
Py_ssize_t n = Py_SIZE(type);
PyMemberDef* mp = type->tp_members;
using c10::intrusive_ptr_target; for (Py_ssize_t i = 0; i < n; i++, mp++) {
using c10::impl::PyObjectSlot; if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
char* addr = (char*)self + mp->offset;
void PyObjectPreservation::init_fresh_nonatomic( PyObject* obj = *(PyObject**)addr;
intrusive_ptr_target* target, if (obj != nullptr) {
PyObjectSlot* slot, *(PyObject**)addr = nullptr;
PyObject* pyobj) { Py_DECREF(obj);
TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr); }
TORCH_INTERNAL_ASSERT(
target->combined_refcount_.load(std::memory_order_relaxed) ==
c10::detail::kUniqueRef);
slot->pyobj_.store(pyobj, std::memory_order_relaxed);
slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed);
target->combined_refcount_.store(
c10::detail::kHasPyObject | c10::detail::kUniqueRef,
std::memory_order_relaxed);
}
PyObject* PyObjectPreservation::init_once(
intrusive_ptr_target* target,
PyObjectSlot* slot,
PyObject* pyobj) {
PyObject* expected = nullptr;
if (!slot->pyobj_.compare_exchange_strong(
expected, pyobj, std::memory_order_acq_rel)) {
TORCH_INTERNAL_ASSERT(expected != nullptr);
return expected;
}
slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_release);
bool increfed = false;
auto combined = target->combined_refcount_.load(std::memory_order_relaxed);
do {
TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined));
if (c10::detail::refcount(combined) > 1 && !increfed) {
// We need to incref the object to preserve the invariant that
// if refcount > 1, the c10 object holds a reference to the PyObject.
// This must happen before we set the kHasPyObject bit.
Py_INCREF(pyobj);
increfed = true;
} }
} while (!target->combined_refcount_.compare_exchange_weak(
combined,
combined | c10::detail::kHasPyObject,
std::memory_order_acq_rel,
std::memory_order_relaxed));
if (increfed && c10::detail::refcount(combined) == 1) {
// Fix up the refcount if we did the incref in a failed compare-exchange
Py_DECREF(pyobj);
} }
return pyobj;
} }
} // namespace torch::utils
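init_once above is an instance of a publish-once pattern: many threads may race to install a wrapper pointer, exactly one wins, and the losers adopt the winner's value. A generic standalone sketch with std::atomic standing in for the pyobj_ slot:

#include <atomic>

template <typename T>
T* install_once(std::atomic<T*>& slot, T* candidate) {
  T* expected = nullptr;
  if (slot.compare_exchange_strong(
          expected, candidate, std::memory_order_acq_rel)) {
    return candidate;  // we published our pointer
  }
  return expected;     // another thread won the race; adopt its pointer
}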
@ -4,28 +4,4 @@
// This file contains utilities used for handling PyObject preservation // This file contains utilities used for handling PyObject preservation
namespace c10 { void clear_slots(PyTypeObject* type, PyObject* self);
class intrusive_ptr_target;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::utils {
class PyObjectPreservation {
public:
// Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold
// a unique reference to `target`.
static void init_fresh_nonatomic(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);
static PyObject* init_once(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);
};
} // namespace torch::utils
@ -207,19 +207,12 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default and node.target is torch.ops.aten._local_scalar_dense.default
): ):
dtype = node.args[0].meta["val"].dtype dtype = node.args[0].meta["val"].dtype
if not dtype.is_floating_point:
continue
assert isinstance(node.args[0], fx.Node), node.args[0] assert isinstance(node.args[0], fx.Node), node.args[0]
s = node.meta["val"].node.expr s = node.meta["val"].node.expr
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
continue
expr_to_tensor_proxy[s] = MetaProxy( expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode node.args[0], tracer=tracer, fake_mode=fake_mode
) )
@ -227,7 +220,9 @@ def tensorify_python_scalars(
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default( expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64 expr_to_tensor_proxy[s], torch.float64
) )
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# pyrefly: ignore [bad-argument-type] # pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None: elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance( if sym_expr not in expr_to_sym_proxy and not isinstance(
@ -387,7 +387,7 @@ class DTensorTestBase(MultiProcessTestCase):
@property @property
def backend(self) -> str: def backend(self) -> str:
backend = dist.get_default_backend_for_device(DEVICE_TYPE) backend = dist.get_default_backend_for_device(self.device_type)
return backend return backend
def init_manual_seed_for_rank(self) -> None: def init_manual_seed_for_rank(self) -> None: