[PGNCCL] Use recursive mutex in NCCLComm (#138997)

Fixes #138995: [PGNCCL][BUG] mutex acquired in recursive way may deadlock

The fix: replace `std::mutex` with `std::recursive_mutex`, so that a thread already holding the lock can acquire it again without deadlocking.

Found and proposed by @dsjohns2. Thanks!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138997
Approved by: https://github.com/dsjohns2
This commit is contained in:
Ke Wen
2024-10-26 13:32:40 -07:00
committed by PyTorch MergeBot
parent 4681539f42
commit 1152726feb
2 changed files with 14 additions and 11 deletions

View File

@ -16,7 +16,7 @@
namespace c10d {
ncclComm_t NCCLComm::getNcclComm() {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
if (aborted_) {
auto commFailureMsg = commFailureReason_ != std::nullopt
? c10::str(" Original reason for failure was: ", *commFailureReason_)

View File

@ -275,6 +275,9 @@ class TORCH_API DebugInfoWriter {
// RAII wrapper for NCCL communicator
class NCCLComm {
using MutexType = std::recursive_mutex;
using LockType = std::unique_lock<MutexType>;
public:
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
@ -283,7 +286,7 @@ class NCCLComm {
~NCCLComm() noexcept {
// Add lock in this destructor, as aborted_ needs to be read after memory
// barrier here.
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
if (ncclComm_ && initialized_ && !aborted_) {
#ifdef ENABLE_NCCL_ERROR_CHECKING
// Use ncclCommAbort instead of ncclCommDestroy here since
@ -371,7 +374,7 @@ class NCCLComm {
NCCLComm(NCCLComm&& other) {
// Using other's lock, as it reads other's states
// Can not use this.mutex_, as this object is being constructed.
std::unique_lock<std::mutex> lock(other.mutex_);
LockType lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
@ -382,13 +385,13 @@ class NCCLComm {
ncclComm_t getNcclComm();
// Returns a copy of commFailureReason_ (empty optional if none is stored).
// The read is performed while holding mutex_; the lock is released when the
// function returns.
std::optional<std::string> getNcclCommFailureReason() const {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
return commFailureReason_;
}
void ncclCommAbort(
std::optional<std::string> commFailureReason = std::nullopt) {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_ && !initialized_) {
// Should not abort twice.
@ -436,12 +439,12 @@ class NCCLComm {
}
// Returns the current value of initialized_, read while holding mutex_.
bool isInitialized() const {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
return initialized_;
}
// Returns the current value of aborted_, read while holding mutex_.
bool isAborted() const {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
return aborted_;
}
@ -450,7 +453,7 @@ class NCCLComm {
}
ncclResult_t checkForNcclError() {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
@ -465,7 +468,7 @@ class NCCLComm {
}
ncclResult_t registerSegment(void* ptr, size_t size) {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
// We register only segments from cache allocator
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
@ -498,7 +501,7 @@ class NCCLComm {
}
ncclResult_t deregisterSegment(void* ptr) {
std::unique_lock<std::mutex> lock(mutex_);
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 1,
@ -538,7 +541,7 @@ class NCCLComm {
bool aborted_{false};
uint64_t ncclCommSplitCounter_{0};
ncclResult_t ncclAsyncErr_{ncclSuccess};
mutable std::mutex mutex_;
mutable MutexType mutex_;
// Rank that this communicator corresponds to.
int rank_{};
// Optional reason for communicator failure, provided by ProcessGroupNCCL for