Combine strong and weak refcounts in intrusive_ptr into a single refcount (#163394)

Summary:
Currently, we assume that refcount_ and weakcount_ are always stored right next to each other at an 8-byte-aligned address. Based on this assumption, we load 8 bytes in intrusive_ptr::reset_ to read both counts at once. However, the C++ language standard does not guarantee that layout, so the load is essentially undefined behavior.

This change eliminates that assumption by combining refcount_ and weakcount_ into a single 64-bit count: the lower 32 bits hold refcount_ and the upper 32 bits hold weakcount_.
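
Concretely, the packing works as in the minimal sketch below (a standalone illustration with hypothetical names, not the exact helpers added by this PR):

#include <atomic>
#include <cstdint>

// Low 32 bits: strong refcount. High 32 bits: weak refcount.
constexpr uint64_t kRefOne  = 1ULL;
constexpr uint64_t kWeakOne = 1ULL << 32;

inline uint32_t refcount(uint64_t combined)  { return static_cast<uint32_t>(combined); }
inline uint32_t weakcount(uint64_t combined) { return static_cast<uint32_t>(combined >> 32); }

// A single 64-bit atomic carries both counts, so inspecting or updating the
// pair is one well-defined operation on one object, instead of an 8-byte load
// spanning two separate 32-bit atomics.
std::atomic<uint64_t> combined_refcount{kRefOne | kWeakOne};  // refcount = 1, weakcount = 1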

In addition to eliminating the undefined behavior, the change removes the separate read of weakcount_ after decrementing refcount_ in intrusive_ptr::reset_. This claws back the performance lost in https://github.com/pytorch/pytorch/pull/162784 for non-final refcount_ decrements.
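
The extra load disappears because fetch_sub returns the previous packed value, so the decrementing thread learns both counts from the one RMW. A rough sketch, again with illustrative names rather than the PR's actual helpers (the path where weakcount > 1 is omitted):

#include <atomic>
#include <cstdint>

constexpr uint64_t kRefOne  = 1ULL;
constexpr uint64_t kWeakOne = 1ULL << 32;

// Drop one strong reference. Returns true when the caller should delete the
// object outright: the strong count reached zero and the only weak count left
// is the implicit one held on behalf of the strong references.
bool decref_is_final(std::atomic<uint64_t>& combined) {
  // acq_rel: non-final decrements release, the final decrement acquires.
  const uint64_t after =
      combined.fetch_sub(kRefOne, std::memory_order_acq_rel) - kRefOne;
  // Both halves are visible in the returned value, so no follow-up load of
  // the weak count is needed. refcount == 0 && weakcount == 1 is exactly the
  // single comparison against kWeakOne.
  return after == kWeakOne;
}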

Reviewed By: yfeldblum

Differential Revision: D82869192

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163394
Approved by: https://github.com/Skylion007
Author: Ben Niu
Date: 2025-09-22 17:53:28 +00:00
Committed by: PyTorch MergeBot
Parent: 5e7be98800
Commit: 281f8f407e
2 changed files with 145 additions and 120 deletions


@@ -35,26 +35,26 @@ struct ExclusivelyOwnedTensorTraits {
// incremented.
const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
toDestroy->refcount() == 1 ||
(toDestroy->refcount() == 0 && isUndefined),
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
isUndefined,
" and refcount ",
toDestroy->refcount_,
toDestroy->refcount(),
", expected 1 or, if isUndefined, 0!");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
toDestroy->weakcount_ == 1 ||
(toDestroy->weakcount_ == 0 &&
toDestroy->weakcount() == 1 ||
(toDestroy->weakcount() == 0 &&
toDestroy == UndefinedTensorImpl::singleton()),
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
isUndefined,
" and weakcount ",
toDestroy->weakcount_,
toDestroy->weakcount(),
", expected 1 or, if isUndefined, 0!");
if (!isUndefined) {
#ifndef NDEBUG
// Needed to pass the debug assertions in ~intrusive_ptr_target.
toDestroy->refcount_ = 0;
toDestroy->weakcount_ = 0;
toDestroy->combined_refcount_.store(0, std::memory_order_relaxed);
#endif
delete toDestroy;
}


@@ -27,7 +27,78 @@ struct DontIncreaseRefcount {};
} // namespace raw
namespace detail {
constexpr uint32_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
constexpr uint64_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
(kImpracticallyHugeReferenceCount << 32);
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
template <class TTarget>
struct intrusive_target_default_null_type final {
static constexpr TTarget* singleton() noexcept {
return nullptr;
}
};
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
if (FromNullType::singleton() == rhs) {
return ToNullType::singleton();
} else {
return rhs;
}
}
inline uint32_t refcount(uint64_t combined_refcount) {
return static_cast<uint32_t>(combined_refcount);
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>(combined_refcount >> 32);
}
// The only requirement for refcount increment is that it happens-before
// decrement, so no additional memory ordering is needed.
inline uint64_t atomic_combined_refcount_increment(
std::atomic<uint64_t>& combined_refcount,
uint64_t inc) {
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
combined_refcount, kWeakReferenceCountOne));
}
// The requirement is that all modifications to the managed object happen-before
// invocation of the managed object destructor, and that allocation of the
// managed object storage happens-before deallocation of the storage.
//
// To get this ordering, all non-final decrements must synchronize-with the
// final decrement. So all non-final decrements have to store-release while the
// final decrement has to load-acquire, either directly or with the help of
// fences. But it's easiest just to have all decrements be acq-rel. And it turns
// out, on modern architectures and chips, it's also fastest.
inline uint64_t atomic_combined_refcount_decrement(
std::atomic<uint64_t>& combined_refcount,
uint64_t dec) {
return combined_refcount.fetch_sub(dec, std::memory_order_acq_rel) - dec;
}
inline uint32_t atomic_weakcount_decrement(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
} // namespace detail
/**
@@ -80,8 +151,14 @@ class C10_API intrusive_ptr_target {
// atomically increment the use count, if it is greater than 0.
// If it is not, you must report that the storage is dead.
//
mutable std::atomic<uint32_t> refcount_;
mutable std::atomic<uint32_t> weakcount_;
// We use a single combined count for refcount and weakcount so that
// we can operate on both atomically at the same time, for performance
// and for well-defined behavior.
//
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
static_assert(std::atomic<uint64_t>::is_always_lock_free);
template <typename T, typename NullType>
friend class intrusive_ptr;
@@ -126,16 +203,16 @@ class C10_API intrusive_ptr_target {
// caller of unsafe_adapt_non_heap_allocated wanted to
// use). We choose our reference count such that the count
// will not dip below kImpracticallyHugeReferenceCount regardless.
refcount_.load() == 0 ||
refcount_.load() >= detail::kImpracticallyHugeReferenceCount,
refcount() == 0 ||
refcount() >= detail::kImpracticallyHugeReferenceCount,
"Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
refcount_.load());
refcount());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
// See ~intrusive_ptr for optimization that will frequently result in 1
// at destruction time.
weakcount_.load() == 1 || weakcount_.load() == 0 ||
weakcount_.load() == detail::kImpracticallyHugeReferenceCount - 1 ||
weakcount_.load() == detail::kImpracticallyHugeReferenceCount,
weakcount() == 1 || weakcount() == 0 ||
weakcount() == detail::kImpracticallyHugeReferenceCount - 1 ||
weakcount() == detail::kImpracticallyHugeReferenceCount,
"Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
@@ -144,7 +221,7 @@ class C10_API intrusive_ptr_target {
#endif
}
constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}
constexpr intrusive_ptr_target() noexcept : combined_refcount_(0) {}
// intrusive_ptr_target supports copy and move: but refcount and weakcount
// don't participate (since they are intrinsic properties of the memory
@@ -177,54 +254,17 @@ class C10_API intrusive_ptr_target {
* destructed), this function WILL NOT be called.
*/
virtual void release_resources() {}
};
namespace detail {
template <class TTarget>
struct intrusive_target_default_null_type final {
static constexpr TTarget* singleton() noexcept {
return nullptr;
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
uint32_t weakcount(
std::memory_order order = std::memory_order_relaxed) const {
return detail::weakcount(combined_refcount_.load(order));
}
};
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
if (FromNullType::singleton() == rhs) {
return ToNullType::singleton();
} else {
return rhs;
}
}
// The only requirement for refcount increment is that it happens-before
// decrement, so no additional memory ordering is needed.
inline uint32_t atomic_refcount_increment(std::atomic<uint32_t>& refcount) {
return refcount.fetch_add(1, std::memory_order_relaxed) + 1;
}
inline uint32_t atomic_weakcount_increment(std::atomic<uint32_t>& weakcount) {
return weakcount.fetch_add(1, std::memory_order_relaxed) + 1;
}
// The requirement is that all modifications to the managed object happen-before
// invocation of the managed object destructor, and that allocation of the
// managed object storage happens-before deallocation of the storage.
//
// To get this ordering, all non-final decrements must synchronize-with the
// final decrement. So all non-final decrements have to store-release while the
// final decrement has to load-acquire, either directly or with the help of
// fences. But it's easiest just to have all decrements be acq-rel. And it turns
// out, on modern architectures and chips, it's also fastest.
inline uint32_t atomic_refcount_decrement(std::atomic<uint32_t>& refcount) {
return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}
inline uint32_t atomic_weakcount_decrement(std::atomic<uint32_t>& weakcount) {
return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@@ -275,7 +315,7 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->refcount_);
detail::atomic_refcount_increment(target_->combined_refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
@@ -284,41 +324,25 @@ class intrusive_ptr final {
void reset_() noexcept {
if (target_ != NullType::singleton()) {
#if defined(__linux__) && (defined(__aarch64__) || defined(__x86_64__))
if constexpr (
std::atomic<uint64_t>::is_always_lock_free &&
std::atomic<uint32_t>::is_always_lock_free &&
sizeof(std::atomic<uint64_t>) == 8 &&
sizeof(std::atomic<uint32_t>) == 4) {
auto both_counts_ =
reinterpret_cast<std::atomic<uint64_t>*>(&target_->refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(reinterpret_cast<std::uintptr_t>(both_counts_) %
sizeof(std::atomic<uint64_t>)) == 0 &&
(reinterpret_cast<std::uintptr_t>(&target_->weakcount_) -
reinterpret_cast<std::uintptr_t>(both_counts_)) ==
sizeof(std::atomic<uint32_t>));
// 0x100000001ULL is a 64-bit number combination of both the refcount_
// and weakcount_ being 1.
constexpr uint64_t unique_ref_ = 0x100000001ULL;
if (both_counts_->load(std::memory_order_acquire) == unique_ref_) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
// call (e.g. calling use_count()) without a data race.
target_->refcount_.store(0, std::memory_order_relaxed);
delete target_;
return;
}
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
// call (e.g. calling use_count()) without a data race.
target_->combined_refcount_.store(0, std::memory_order_relaxed);
delete target_;
return;
}
#endif
if (detail::atomic_refcount_decrement(target_->refcount_) == 0) {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
bool should_delete =
target_->weakcount_.load(std::memory_order_acquire) == 1;
if (!should_delete) {
// justification for const_cast: release_resources is basically a
// destructor and a destructor always mutates the object, even for
@@ -326,8 +350,8 @@ class intrusive_ptr final {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<std::remove_const_t<TTarget>*>(target_)
->release_resources();
should_delete =
detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
should_delete = detail::atomic_weakcount_decrement(
target_->combined_refcount_) == 0;
}
if (should_delete) {
delete target_;
@@ -354,12 +378,12 @@ class intrusive_ptr final {
// `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
// much more expensive: https://godbolt.org/z/eKPzj8.)
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
target_->refcount_ == 0 && target_->weakcount_ == 0,
target_->combined_refcount_.load(std::memory_order_relaxed) == 0,
"intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
"constructor do something strange like incref or create an "
"intrusive_ptr from `this`?");
target_->refcount_.store(1, std::memory_order_relaxed);
target_->weakcount_.store(1, std::memory_order_relaxed);
target_->combined_refcount_.store(
detail::kUniqueRef, std::memory_order_relaxed);
}
}
@@ -482,14 +506,14 @@ class intrusive_ptr final {
if (target_ == NullType::singleton()) {
return 0;
}
return target_->refcount_.load(std::memory_order_relaxed);
return target_->refcount(std::memory_order_relaxed);
}
uint32_t weak_use_count() const noexcept {
if (target_ == NullType::singleton()) {
return 0;
}
return target_->weakcount_.load(std::memory_order_relaxed);
return target_->weakcount(std::memory_order_relaxed);
}
bool unique() const noexcept {
@@ -518,8 +542,8 @@ class intrusive_ptr final {
*/
static intrusive_ptr reclaim(TTarget* owning_ptr) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
owning_ptr == NullType::singleton() ||
owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
owning_ptr == NullType::singleton() || owning_ptr->refcount() == 0 ||
owning_ptr->weakcount(),
"TTarget violates the invariant that refcount > 0 => weakcount > 0");
return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
}
@@ -590,11 +614,11 @@ class intrusive_ptr final {
#ifdef NDEBUG
expected_decrefs = 0;
#endif
result.target_->refcount_.store(
detail::kImpracticallyHugeReferenceCount + expected_decrefs,
result.target_->combined_refcount_.store(
detail::refcount(
detail::kImpracticallyHugeReferenceCount + expected_decrefs) |
detail::kImpracticallyHugeWeakReferenceCount,
std::memory_order_relaxed);
result.target_->weakcount_.store(
detail::kImpracticallyHugeReferenceCount, std::memory_order_relaxed);
return result;
}
@@ -611,7 +635,7 @@ class intrusive_ptr final {
static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
// See Note [Stack allocated intrusive_ptr_target safety]
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
raw_ptr == NullType::singleton() || raw_ptr->refcount() > 0,
"intrusive_ptr: Can only reclaim pointers that are owned by someone");
auto ptr = reclaim(raw_ptr); // doesn't increase refcount
ptr.retain_();
@@ -745,7 +769,7 @@ class weak_intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint32_t new_weakcount =
detail::atomic_weakcount_increment(target_->weakcount_);
detail::atomic_weakcount_increment(target_->combined_refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_weakcount != 1,
"weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
@@ -754,7 +778,7 @@ class weak_intrusive_ptr final {
void reset_() noexcept {
if (target_ != NullType::singleton() &&
detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
detail::atomic_weakcount_decrement(target_->combined_refcount_) == 0) {
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
delete target_;
}
@@ -887,7 +911,7 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return 0;
}
return target_->refcount_.load(
return target_->refcount(
std::memory_order_relaxed); // refcount, not weakcount!
}
@@ -895,7 +919,7 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return 0;
}
return target_->weakcount_.load(std::memory_order_relaxed);
return target_->weakcount(std::memory_order_relaxed);
}
bool expired() const noexcept {
@@ -906,16 +930,17 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
auto refcount = target_->refcount_.load(std::memory_order_relaxed);
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
if (refcount == 0) {
if (detail::refcount(combined_refcount) == 0) {
// Object already destructed, no strong references left anymore.
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
} while (!target_->refcount_.compare_exchange_weak(
refcount,
refcount + 1,
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
@@ -952,9 +977,9 @@ class weak_intrusive_ptr final {
// if refcount == 0, weakcount need only be > 0.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
owning_weak_ptr == NullType::singleton() ||
owning_weak_ptr->weakcount_.load() > 1 ||
(owning_weak_ptr->refcount_.load() == 0 &&
owning_weak_ptr->weakcount_.load() > 0),
owning_weak_ptr->weakcount() > 1 ||
(owning_weak_ptr->refcount() == 0 &&
owning_weak_ptr->weakcount() > 0),
"weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
return weak_intrusive_ptr(owning_weak_ptr);
}
@@ -1033,7 +1058,7 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
detail::atomic_refcount_increment(self->refcount_);
detail::atomic_refcount_increment(self->combined_refcount_);
}
}
@@ -1067,7 +1092,7 @@ inline uint32_t use_count(intrusive_ptr_target* self) {
namespace weak_intrusive_ptr {
inline void incref(weak_intrusive_ptr_target* self) {
detail::atomic_weakcount_increment(self->weakcount_);
detail::atomic_weakcount_increment(self->combined_refcount_);
}
inline void decref(weak_intrusive_ptr_target* self) {