diff --git a/c10/util/ExclusivelyOwnedTensorTraits.h b/c10/util/ExclusivelyOwnedTensorTraits.h
index 73ff45b8c38d..f19df3089f77 100644
--- a/c10/util/ExclusivelyOwnedTensorTraits.h
+++ b/c10/util/ExclusivelyOwnedTensorTraits.h
@@ -35,26 +35,26 @@ struct ExclusivelyOwnedTensorTraits {
     // incremented.
     const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
+        toDestroy->refcount() == 1 ||
+            (toDestroy->refcount() == 0 && isUndefined),
         "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
         isUndefined,
         " and refcount ",
-        toDestroy->refcount_,
+        toDestroy->refcount(),
         ", expected 1 or, if isUndefined, 0!");
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        toDestroy->weakcount_ == 1 ||
-            (toDestroy->weakcount_ == 0 &&
+        toDestroy->weakcount() == 1 ||
+            (toDestroy->weakcount() == 0 &&
              toDestroy == UndefinedTensorImpl::singleton()),
         "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
         isUndefined,
         " and weakcount ",
-        toDestroy->weakcount_,
+        toDestroy->weakcount(),
         ", expected 1 or, if isUndefined, 0!");
     if (!isUndefined) {
 #ifndef NDEBUG
       // Needed to pass the debug assertions in ~intrusive_ptr_target.
-      toDestroy->refcount_ = 0;
-      toDestroy->weakcount_ = 0;
+      toDestroy->combined_refcount_.store(0, std::memory_order_relaxed);
 #endif
       delete toDestroy;
     }
diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h
index 449910cbb29e..1f89b2799ad6 100644
--- a/c10/util/intrusive_ptr.h
+++ b/c10/util/intrusive_ptr.h
@@ -27,7 +27,78 @@ struct DontIncreaseRefcount {};
 } // namespace raw
 
 namespace detail {
-constexpr uint32_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
+constexpr uint64_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
+constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
+    (kImpracticallyHugeReferenceCount << 32);
+constexpr uint64_t kReferenceCountOne = 1;
+constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
+constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
+
+template <class TTarget>
+struct intrusive_target_default_null_type final {
+  static constexpr TTarget* singleton() noexcept {
+    return nullptr;
+  }
+};
+
+template <class TTarget, class ToNullType, class FromNullType>
+TTarget* assign_ptr_(TTarget* rhs) {
+  if (FromNullType::singleton() == rhs) {
+    return ToNullType::singleton();
+  } else {
+    return rhs;
+  }
+}
+
+inline uint32_t refcount(uint64_t combined_refcount) {
+  return static_cast<uint32_t>(combined_refcount);
+}
+
+inline uint32_t weakcount(uint64_t combined_refcount) {
+  return static_cast<uint32_t>(combined_refcount >> 32);
+}
+
+// The only requirement for refcount increment is that it happens-before
+// decrement, so no additional memory ordering is needed.
+inline uint64_t atomic_combined_refcount_increment(
+    std::atomic<uint64_t>& combined_refcount,
+    uint64_t inc) {
+  return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
+}
+
+inline uint32_t atomic_refcount_increment(
+    std::atomic<uint64_t>& combined_refcount) {
+  return detail::refcount(atomic_combined_refcount_increment(
+      combined_refcount, kReferenceCountOne));
+}
+
+inline uint32_t atomic_weakcount_increment(
+    std::atomic<uint64_t>& combined_refcount) {
+  return detail::weakcount(atomic_combined_refcount_increment(
+      combined_refcount, kWeakReferenceCountOne));
+}
+
+// The requirement is that all modifications to the managed object happen-before
+// invocation of the managed object destructor, and that allocation of the
+// managed object storage happens-before deallocation of the storage.
+//
+// To get this ordering, all non-final decrements must synchronize-with the
+// final decrement. So all non-final decrements have to store-release while the
+// final decrement has to load-acquire, either directly or with the help of
+// fences. But it's easiest just to have all decrements be acq-rel. And it turns
+// out, on modern architectures and chips, it's also fastest.
+inline uint64_t atomic_combined_refcount_decrement(
+    std::atomic<uint64_t>& combined_refcount,
+    uint64_t dec) {
+  return combined_refcount.fetch_sub(dec, std::memory_order_acq_rel) - dec;
+}
+
+inline uint32_t atomic_weakcount_decrement(
+    std::atomic<uint64_t>& combined_refcount) {
+  return detail::weakcount(atomic_combined_refcount_decrement(
+      combined_refcount, kWeakReferenceCountOne));
+}
+
 } // namespace detail
 
 /**
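The constants above give the packed layout its vocabulary: the strong refcount occupies the low 32 bits of the combined word and the weak count the high 32 bits, so a single 64-bit atomic operation can adjust or inspect both. A standalone sketch of the arithmetic (illustrative only, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  constexpr uint64_t kReferenceCountOne = 1;                            // +1 strong
  constexpr uint64_t kWeakReferenceCountOne = kReferenceCountOne << 32; // +1 weak

  uint64_t combined = 0;
  combined += 3 * kReferenceCountOne;     // take three strong references
  combined += 2 * kWeakReferenceCountOne; // take two weak references

  // detail::refcount / detail::weakcount reduce to a truncating cast and a shift.
  assert(static_cast<uint32_t>(combined) == 3);
  assert(static_cast<uint32_t>(combined >> 32) == 2);

  // kUniqueRef is the state "exactly one strong and one weak reference".
  assert((kReferenceCountOne | kWeakReferenceCountOne) == 0x100000001ULL);
  return 0;
}
```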
@@ -80,8 +151,14 @@ class C10_API intrusive_ptr_target {
   // atomically increment the use count, if it is greater than 0.
   // If it is not, you must report that the storage is dead.
   //
-  mutable std::atomic<uint32_t> refcount_;
-  mutable std::atomic<uint32_t> weakcount_;
+  // We use a single combined count for refcount and weakcount so that
+  // we can atomically operate on both at the same time for performance
+  // and well-defined behavior.
+  //
+  mutable std::atomic<uint64_t> combined_refcount_;
+  static_assert(sizeof(std::atomic<uint64_t>) == 8);
+  static_assert(alignof(std::atomic<uint64_t>) == 8);
+  static_assert(std::atomic<uint64_t>::is_always_lock_free);
 
   template <typename T, typename NullType>
   friend class intrusive_ptr;
@@ -126,16 +203,16 @@ class C10_API intrusive_ptr_target {
       // caller of unsafe_adapt_non_heap_allocated wanted to
       // use). We choose our reference count such that the count
       // will not dip below kImpracticallyHugeReferenceCount regardless.
-        refcount_.load() == 0 ||
-            refcount_.load() >= detail::kImpracticallyHugeReferenceCount,
+        refcount() == 0 ||
+            refcount() >= detail::kImpracticallyHugeReferenceCount,
         "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
-        refcount_.load());
+        refcount());
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         // See ~intrusive_ptr for optimization that will frequently result in 1
         // at destruction time.
-        weakcount_.load() == 1 || weakcount_.load() == 0 ||
-            weakcount_.load() == detail::kImpracticallyHugeReferenceCount - 1 ||
-            weakcount_.load() == detail::kImpracticallyHugeReferenceCount,
+        weakcount() == 1 || weakcount() == 0 ||
+            weakcount() == detail::kImpracticallyHugeReferenceCount - 1 ||
+            weakcount() == detail::kImpracticallyHugeReferenceCount,
         "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(pop)
@@ -144,7 +221,7 @@ class C10_API intrusive_ptr_target {
 #endif
   }
 
-  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}
+  constexpr intrusive_ptr_target() noexcept : combined_refcount_(0) {}
 
   // intrusive_ptr_target supports copy and move: but refcount and weakcount
   // don't participate (since they are intrinsic properties of the memory
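The acq-rel policy described in the relocated comment applies unchanged to the combined counter. A minimal sketch of the idea, reduced to a single strong count (not the patch's code):

```cpp
#include <atomic>
#include <cstdint>

// Increments may be relaxed: handing a reference to another thread already
// happens-before that thread's decrement. Decrements are acq-rel so that each
// non-final decrement's release half synchronizes-with the final decrement's
// acquire half, making all prior writes visible to the destroying thread.
struct Counted {
  std::atomic<uint64_t> combined{1};

  void incref() {
    combined.fetch_add(1, std::memory_order_relaxed);
  }

  void decref() {
    if (combined.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      delete this; // final decrement: destruction is safely ordered
    }
  }
};
```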
@@ -177,54 +254,17 @@ class C10_API intrusive_ptr_target {
    * destructed), this function WILL NOT be called.
    */
   virtual void release_resources() {}
-};
 
-namespace detail {
-template <class TTarget>
-struct intrusive_target_default_null_type final {
-  static constexpr TTarget* singleton() noexcept {
-    return nullptr;
+  uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
+    return detail::refcount(combined_refcount_.load(order));
+  }
+
+  uint32_t weakcount(
+      std::memory_order order = std::memory_order_relaxed) const {
+    return detail::weakcount(combined_refcount_.load(order));
   }
 };
 
-template <class TTarget, class ToNullType, class FromNullType>
-TTarget* assign_ptr_(TTarget* rhs) {
-  if (FromNullType::singleton() == rhs) {
-    return ToNullType::singleton();
-  } else {
-    return rhs;
-  }
-}
-
-// The only requirement for refcount increment is that it happens-before
-// decrement, so no additional memory ordering is needed.
-inline uint32_t atomic_refcount_increment(std::atomic<uint32_t>& refcount) {
-  return refcount.fetch_add(1, std::memory_order_relaxed) + 1;
-}
-
-inline uint32_t atomic_weakcount_increment(std::atomic<uint32_t>& weakcount) {
-  return weakcount.fetch_add(1, std::memory_order_relaxed) + 1;
-}
-
-// The requirement is that all modifications to the managed object happen-before
-// invocation of the managed object destructor, and that allocation of the
-// managed object storage happens-before deallocation of the storage.
-//
-// To get this ordering, all non-final decrements must synchronize-with the
-// final decrement. So all non-final decrements have to store-release while the
-// final decrement has to load-acquire, either directly or with the help of
-// fences. But it's easiest just to have all decrements be acq-rel. And it turns
-// out, on modern architectures and chips, it's also fastest.
-inline uint32_t atomic_refcount_decrement(std::atomic<uint32_t>& refcount) {
-  return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
-}
-
-inline uint32_t atomic_weakcount_decrement(std::atomic<uint32_t>& weakcount) {
-  return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
-}
-
-} // namespace detail
-
 template <typename TTarget, typename NullType>
 class weak_intrusive_ptr;
 
@@ -275,7 +315,7 @@ class intrusive_ptr final {
   void retain_() {
     if (target_ != NullType::singleton()) {
       uint32_t new_refcount =
-          detail::atomic_refcount_increment(target_->refcount_);
+          detail::atomic_refcount_increment(target_->combined_refcount_);
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
           new_refcount != 1,
           "intrusive_ptr: Cannot increase refcount after it reached zero.");
@@ -284,41 +324,25 @@ class intrusive_ptr final {
 
   void reset_() noexcept {
     if (target_ != NullType::singleton()) {
-#if defined(__linux__) && (defined(__aarch64__) || defined(__x86_64__))
-      if constexpr (
-          std::atomic<uint64_t>::is_always_lock_free &&
-          std::atomic<uint32_t>::is_always_lock_free &&
-          sizeof(std::atomic<uint64_t>) == 8 &&
-          sizeof(std::atomic<uint32_t>) == 4) {
-        auto both_counts_ =
-            reinterpret_cast<std::atomic<uint64_t>*>(&target_->refcount_);
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-            (reinterpret_cast<uintptr_t>(both_counts_) %
-             sizeof(std::atomic<uint64_t>)) == 0 &&
-            (reinterpret_cast<uintptr_t>(&target_->weakcount_) -
-             reinterpret_cast<uintptr_t>(both_counts_)) ==
-                sizeof(std::atomic<uint32_t>));
-        // 0x100000001ULL is a 64-bit number combination of both the refcount_
-        // and weakcount_ being 1.
-        constexpr uint64_t unique_ref_ = 0x100000001ULL;
-        if (both_counts_->load(std::memory_order_acquire) == unique_ref_) {
-          // Both counts are 1, so there are no weak references and
-          // we are releasing the last strong reference. No other
-          // threads can observe the effects of this target_ deletion
-          // call (e.g. calling use_count()) without a data race.
-          target_->refcount_.store(0, std::memory_order_relaxed);
-          delete target_;
-          return;
-        }
+      if (target_->combined_refcount_.load(std::memory_order_acquire) ==
+          detail::kUniqueRef) {
+        // Both counts are 1, so there are no weak references and
+        // we are releasing the last strong reference. No other
+        // threads can observe the effects of this target_ deletion
+        // call (e.g. calling use_count()) without a data race.
+        target_->combined_refcount_.store(0, std::memory_order_relaxed);
+        delete target_;
+        return;
       }
-#endif
-      if (detail::atomic_refcount_decrement(target_->refcount_) == 0) {
+
+      auto combined_refcount = detail::atomic_combined_refcount_decrement(
+          target_->combined_refcount_, detail::kReferenceCountOne);
+      if (detail::refcount(combined_refcount) == 0) {
+        bool should_delete =
+            (combined_refcount == detail::kWeakReferenceCountOne);
         // See comment above about weakcount. As long as refcount>0,
         // weakcount is one larger than the actual number of weak references.
         // So we need to decrement it here.
-        bool should_delete =
-            target_->weakcount_.load(std::memory_order_acquire) == 1;
         if (!should_delete) {
           // justification for const_cast: release_resources is basically a
           // destructor and a destructor always mutates the object, even for
@@ -326,8 +350,8 @@ class intrusive_ptr final {
           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
           const_cast<std::remove_const_t<TTarget>*>(target_)
               ->release_resources();
-          should_delete =
-              detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
+          should_delete = detail::atomic_weakcount_decrement(
+                              target_->combined_refcount_) == 0;
         }
         if (should_delete) {
           delete target_;
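With the counts fused, the slow path of reset_() no longer needs the old separate acquire load of weakcount_ to decide whether to delete: the value returned by the combined fetch_sub already carries both counts. A sketch of the two release paths (Widget and release() are illustrative stand-ins, not the patch's API):

```cpp
#include <atomic>
#include <cstdint>

constexpr uint64_t kRefOne = 1;
constexpr uint64_t kWeakOne = kRefOne << 32;
constexpr uint64_t kUnique = kRefOne | kWeakOne; // one strong + one weak

struct Widget {
  std::atomic<uint64_t> combined{kUnique};
};

void release(Widget* w) {
  // Fast path: we hold the only reference of any kind, so nobody can race
  // with the delete; plain stores suffice after the acquire load.
  if (w->combined.load(std::memory_order_acquire) == kUnique) {
    w->combined.store(0, std::memory_order_relaxed);
    delete w;
    return;
  }
  // Slow path: one RMW decrements the strong count and reports both counts,
  // so "last strong ref and no weak refs" is a single comparison.
  uint64_t after =
      w->combined.fetch_sub(kRefOne, std::memory_order_acq_rel) - kRefOne;
  if (static_cast<uint32_t>(after) == 0) { // strong count reached zero
    if (after == kWeakOne) {
      delete w; // no outstanding weak references either
    }
    // else: release resources, then drop the weak count (elided here).
  }
}
```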
@@ -354,12 +378,12 @@ class intrusive_ptr final {
     // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
     // much more expensive: https://godbolt.org/z/eKPzj8.)
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        target_->refcount_ == 0 && target_->weakcount_ == 0,
+        target_->combined_refcount_.load(std::memory_order_relaxed) == 0,
         "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
         "constructor do something strange like incref or create an "
         "intrusive_ptr from `this`?");
-    target_->refcount_.store(1, std::memory_order_relaxed);
-    target_->weakcount_.store(1, std::memory_order_relaxed);
+    target_->combined_refcount_.store(
+        detail::kUniqueRef, std::memory_order_relaxed);
   }
 }
@@ -482,14 +506,14 @@ class intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->refcount_.load(std::memory_order_relaxed);
+    return target_->refcount(std::memory_order_relaxed);
   }
 
   uint32_t weak_use_count() const noexcept {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->weakcount_.load(std::memory_order_relaxed);
+    return target_->weakcount(std::memory_order_relaxed);
   }
 
   bool unique() const noexcept {
@@ -518,8 +542,8 @@ class intrusive_ptr final {
    */
   static intrusive_ptr reclaim(TTarget* owning_ptr) {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        owning_ptr == NullType::singleton() ||
-            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
+        owning_ptr == NullType::singleton() || owning_ptr->refcount() == 0 ||
+            owning_ptr->weakcount(),
         "TTarget violates the invariant that refcount > 0 => weakcount > 0");
     return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
   }
@@ -590,11 +614,11 @@ class intrusive_ptr final {
 #ifdef NDEBUG
     expected_decrefs = 0;
 #endif
-    result.target_->refcount_.store(
-        detail::kImpracticallyHugeReferenceCount + expected_decrefs,
+    result.target_->combined_refcount_.store(
+        detail::refcount(
+            detail::kImpracticallyHugeReferenceCount + expected_decrefs) |
+            detail::kImpracticallyHugeWeakReferenceCount,
         std::memory_order_relaxed);
-    result.target_->weakcount_.store(
-        detail::kImpracticallyHugeReferenceCount, std::memory_order_relaxed);
     return result;
   }
@@ -611,7 +635,7 @@ class intrusive_ptr final {
   static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
     // See Note [Stack allocated intrusive_ptr_target safety]
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
+        raw_ptr == NullType::singleton() || raw_ptr->refcount() > 0,
         "intrusive_ptr: Can only reclaim pointers that are owned by someone");
     auto ptr = reclaim(raw_ptr); // doesn't increase refcount
     ptr.retain_();
@@ -745,7 +769,7 @@ class weak_intrusive_ptr final {
   void retain_() {
     if (target_ != NullType::singleton()) {
      uint32_t new_weakcount =
-          detail::atomic_weakcount_increment(target_->weakcount_);
+          detail::atomic_weakcount_increment(target_->combined_refcount_);
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
           new_weakcount != 1,
           "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
@@ -754,7 +778,7 @@ class weak_intrusive_ptr final {
 
   void reset_() noexcept {
     if (target_ != NullType::singleton() &&
-        detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
+        detail::atomic_weakcount_decrement(target_->combined_refcount_) == 0) {
       // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
       delete target_;
     }
@@ -887,7 +911,7 @@ class weak_intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->refcount_.load(
+    return target_->refcount(
         std::memory_order_relaxed); // refcount, not weakcount!
   }
 
@@ -895,7 +919,7 @@ class weak_intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->weakcount_.load(std::memory_order_relaxed);
+    return target_->weakcount(std::memory_order_relaxed);
   }
 
   bool expired() const noexcept {
@@ -906,16 +930,17 @@ class weak_intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return intrusive_ptr<TTarget, NullType>();
     } else {
-      auto refcount = target_->refcount_.load(std::memory_order_relaxed);
+      auto combined_refcount =
+          target_->combined_refcount_.load(std::memory_order_relaxed);
       do {
-        if (refcount == 0) {
+        if (detail::refcount(combined_refcount) == 0) {
           // Object already destructed, no strong references left anymore.
           // Return nullptr.
           return intrusive_ptr<TTarget, NullType>();
         }
-      } while (!target_->refcount_.compare_exchange_weak(
-          refcount,
-          refcount + 1,
+      } while (!target_->combined_refcount_.compare_exchange_weak(
+          combined_refcount,
+          combined_refcount + detail::kReferenceCountOne,
           std::memory_order_acquire,
           std::memory_order_relaxed));
@@ -952,9 +977,9 @@ class weak_intrusive_ptr final {
     // if refcount == 0, weakcount only must be >0.
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         owning_weak_ptr == NullType::singleton() ||
-            owning_weak_ptr->weakcount_.load() > 1 ||
-            (owning_weak_ptr->refcount_.load() == 0 &&
-             owning_weak_ptr->weakcount_.load() > 0),
+            owning_weak_ptr->weakcount() > 1 ||
+            (owning_weak_ptr->refcount() == 0 &&
+             owning_weak_ptr->weakcount() > 0),
         "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
     return weak_intrusive_ptr(owning_weak_ptr);
   }
@@ -1033,7 +1058,7 @@ namespace intrusive_ptr {
 // NullType::singleton to this function
 inline void incref(intrusive_ptr_target* self) {
   if (self) {
-    detail::atomic_refcount_increment(self->refcount_);
+    detail::atomic_refcount_increment(self->combined_refcount_);
   }
 }
@@ -1067,7 +1092,7 @@ inline uint32_t use_count(intrusive_ptr_target* self) {
 namespace weak_intrusive_ptr {
 
 inline void incref(weak_intrusive_ptr_target* self) {
-  detail::atomic_weakcount_increment(self->weakcount_);
+  detail::atomic_weakcount_increment(self->combined_refcount_);
 }
 
 inline void decref(weak_intrusive_ptr_target* self) {
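lock() above keeps the standard try-increment-if-nonzero loop, now on the combined word: the CAS adds kReferenceCountOne, which touches only the low 32 bits, and bails out once the strong count has reached zero. A condensed sketch (try_lock_strong is an illustrative name, not the patch's API):

```cpp
#include <atomic>
#include <cstdint>

// Returns true iff a strong reference was acquired.
bool try_lock_strong(std::atomic<uint64_t>& combined) {
  constexpr uint64_t kRefOne = 1; // strong count lives in the low 32 bits
  uint64_t cur = combined.load(std::memory_order_relaxed);
  do {
    if (static_cast<uint32_t>(cur) == 0) {
      return false; // already destructed; a dead object cannot be revived
    }
    // On failure, compare_exchange_weak reloads cur with the current value.
  } while (!combined.compare_exchange_weak(
      cur,
      cur + kRefOne,
      std::memory_order_acquire,
      std::memory_order_relaxed));
  return true;
}
```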