diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h
index 288b19df0a6c..de81d4c1b7df 100644
--- a/c10/util/intrusive_ptr.h
+++ b/c10/util/intrusive_ptr.h
@@ -196,20 +196,25 @@ TTarget* assign_ptr_(TTarget* rhs) {
   }
 }
 
-// Increment needs to be acquire-release to make use_count() and
-// unique() reliable.
+// The only requirement for refcount increment is that it happens-before
+// decrement, so no additional memory ordering is needed.
 inline uint32_t atomic_refcount_increment(std::atomic<uint32_t>& refcount) {
-  return refcount.fetch_add(1, std::memory_order_acq_rel) + 1;
+  return refcount.fetch_add(1, std::memory_order_relaxed) + 1;
 }
 
-// weak_use_count() is only used for testing, so we don't need it to
-// be reliable. Relaxed should be fine.
 inline uint32_t atomic_weakcount_increment(std::atomic<uint32_t>& weakcount) {
   return weakcount.fetch_add(1, std::memory_order_relaxed) + 1;
 }
 
-// Both decrements need to be acquire-release for correctness. See
-// e.g. std::shared_ptr implementation.
+// The requirement is that all modifications to the managed object happen-before
+// invocation of the managed object destructor, and that allocation of the
+// managed object storage happens-before deallocation of the storage.
+//
+// To get this ordering, all non-final decrements must synchronize-with the
+// final decrement. So all non-final decrements have to store-release while the
+// final decrement has to load-acquire, either directly or with the help of
+// fences. But it's easiest just to have all decrements be acq-rel. And it turns
+// out, on modern architectures and chips, it's also fastest.
 inline uint32_t atomic_refcount_decrement(std::atomic<uint32_t>& refcount) {
   return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
 }
@@ -332,7 +337,7 @@ class intrusive_ptr final {
   intrusive_ptr() noexcept
       : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}
 
-  intrusive_ptr(std::nullptr_t) noexcept
+  /* implicit */ intrusive_ptr(std::nullptr_t) noexcept
       : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}
 
   // This constructor will not increase the ref counter for you.
@@ -445,14 +450,14 @@ class intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->refcount_.load(std::memory_order_acquire);
+    return target_->refcount_.load(std::memory_order_relaxed);
   }
 
   uint32_t weak_use_count() const noexcept {
     if (target_ == NullType::singleton()) {
       return 0;
     }
-    return target_->weakcount_.load(std::memory_order_acquire);
+    return target_->weakcount_.load(std::memory_order_relaxed);
   }
 
   bool unique() const noexcept {
@@ -851,14 +856,14 @@ class weak_intrusive_ptr final {
       return 0;
     }
     return target_->refcount_.load(
-        std::memory_order_acquire); // refcount, not weakcount!
+        std::memory_order_relaxed); // refcount, not weakcount!
   }
 
   uint32_t weak_use_count() const noexcept {
     if (target_ == NullType::singleton()) {
      return 0;
    }
-    return target_->weakcount_.load(std::memory_order_acquire);
+    return target_->weakcount_.load(std::memory_order_relaxed);
   }
 
   bool expired() const noexcept {
@@ -866,18 +871,22 @@
   }
 
   intrusive_ptr<TTarget, NullType> lock() const noexcept {
-    if (expired()) {
+    if (target_ == NullType::singleton()) {
       return intrusive_ptr<TTarget, NullType>();
     } else {
-      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
+      auto refcount = target_->refcount_.load(std::memory_order_relaxed);
       do {
         if (refcount == 0) {
           // Object already destructed, no strong references left anymore.
           // Return nullptr.
           return intrusive_ptr<TTarget, NullType>();
         }
-      } while (
-          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
+      } while (!target_->refcount_.compare_exchange_weak(
+          refcount,
+          refcount + 1,
+          std::memory_order_acquire,
+          std::memory_order_relaxed));
+
       return intrusive_ptr<TTarget, NullType>(
           target_, raw::DontIncreaseRefcount{});
     }
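For reference, below is a minimal standalone sketch of the ordering scheme this diff adopts: relaxed increments, acq-rel decrements, and an acquire CAS (relaxed on failure) for weak-to-strong promotion. The RefBlock type and its member names are illustrative only and are not part of c10; this is a simplified model of the technique, not the intrusive_ptr implementation itself.

#include <atomic>
#include <cstdint>

// Hypothetical refcount block demonstrating the memory orderings used above.
struct RefBlock {
  std::atomic<std::uint32_t> refcount{1};

  // Increment only needs to happen-before the matching decrement, which the
  // RMW on the same atomic already provides, so relaxed ordering suffices.
  void incref() {
    refcount.fetch_add(1, std::memory_order_relaxed);
  }

  // Decrement is acq-rel so every non-final decrement (release) is visible to
  // the final decrement (acquire) before the object is destroyed.
  bool decref_and_should_destroy() {
    return refcount.fetch_sub(1, std::memory_order_acq_rel) == 1;
  }

  // Weak-to-strong promotion: the CAS succeeds with acquire so the caller
  // observes the published object; the failure/reload path can stay relaxed.
  bool try_lock_strong() {
    std::uint32_t count = refcount.load(std::memory_order_relaxed);
    do {
      if (count == 0) {
        return false; // object already destructed, cannot resurrect
      }
    } while (!refcount.compare_exchange_weak(
        count,
        count + 1,
        std::memory_order_acquire,
        std::memory_order_relaxed));
    return true;
  }
};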