mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PGNCCL] Use recursive mutex in NCCLComm (#138997)
Fixes #138995: [PGNCCL][BUG] mutex acquired in recursive way may deadlock The fix: use `std::recursive_mutex` to replace `std::mutex`. Found and proposed by @dsjohns2. Thanks! Pull Request resolved: https://github.com/pytorch/pytorch/pull/138997 Approved by: https://github.com/dsjohns2
This commit is contained in:
@ -16,7 +16,7 @@
|
||||
namespace c10d {
|
||||
|
||||
ncclComm_t NCCLComm::getNcclComm() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
if (aborted_) {
|
||||
auto commFailureMsg = commFailureReason_ != std::nullopt
|
||||
? c10::str(" Original reason for failure was: ", *commFailureReason_)
|
||||
|
@ -275,6 +275,9 @@ class TORCH_API DebugInfoWriter {
|
||||
|
||||
// RAII wrapper for NCCL communicator
|
||||
class NCCLComm {
|
||||
using MutexType = std::recursive_mutex;
|
||||
using LockType = std::unique_lock<MutexType>;
|
||||
|
||||
public:
|
||||
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
|
||||
|
||||
@ -283,7 +286,7 @@ class NCCLComm {
|
||||
~NCCLComm() noexcept {
|
||||
// Add lock in this destructor, as aborted_ needs to be read after memory
|
||||
// barrier here.
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
if (ncclComm_ && initialized_ && !aborted_) {
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
// Use ncclCommAbort instead of ncclCommDestroy here since
|
||||
@ -371,7 +374,7 @@ class NCCLComm {
|
||||
NCCLComm(NCCLComm&& other) {
|
||||
// Using other's lock, as it reads other's states
|
||||
// Can not use this.mutex_, as this object is being constructed.
|
||||
std::unique_lock<std::mutex> lock(other.mutex_);
|
||||
LockType lock(other.mutex_);
|
||||
std::swap(ncclComm_, other.ncclComm_);
|
||||
std::swap(aborted_, other.aborted_);
|
||||
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
||||
@ -382,13 +385,13 @@ class NCCLComm {
|
||||
ncclComm_t getNcclComm();
|
||||
|
||||
std::optional<std::string> getNcclCommFailureReason() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
return commFailureReason_;
|
||||
}
|
||||
|
||||
void ncclCommAbort(
|
||||
std::optional<std::string> commFailureReason = std::nullopt) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
if (aborted_ && !initialized_) {
|
||||
// Should not abort twice.
|
||||
@ -436,12 +439,12 @@ class NCCLComm {
|
||||
}
|
||||
|
||||
bool isInitialized() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
return initialized_;
|
||||
}
|
||||
|
||||
bool isAborted() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
return aborted_;
|
||||
}
|
||||
|
||||
@ -450,7 +453,7 @@ class NCCLComm {
|
||||
}
|
||||
|
||||
ncclResult_t checkForNcclError() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
if (ncclAsyncErr_ != ncclSuccess) {
|
||||
return ncclAsyncErr_;
|
||||
@ -465,7 +468,7 @@ class NCCLComm {
|
||||
}
|
||||
|
||||
ncclResult_t registerSegment(void* ptr, size_t size) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
// We register only segments from cache allocator
|
||||
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
|
||||
@ -498,7 +501,7 @@ class NCCLComm {
|
||||
}
|
||||
|
||||
ncclResult_t deregisterSegment(void* ptr) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
LockType lock(mutex_);
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
TORCH_CHECK(
|
||||
registeredSegmentHandles_.count(ptr) == 1,
|
||||
@ -538,7 +541,7 @@ class NCCLComm {
|
||||
bool aborted_{false};
|
||||
uint64_t ncclCommSplitCounter_{0};
|
||||
ncclResult_t ncclAsyncErr_{ncclSuccess};
|
||||
mutable std::mutex mutex_;
|
||||
mutable MutexType mutex_;
|
||||
// Rank that this communicator corresponds to.
|
||||
int rank_{};
|
||||
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
|
||||
|
Reference in New Issue
Block a user