c10d/logging: add C10D_LOCK_GUARD (#134131)

This adds a warning log whenever a lock in NCCLUtils or ProcessGroupNCCL cannot be acquired within 30s.

This is motivated by some deadlocks we're seeing, and it's unclear whether they are in NCCL or on the PyTorch side of things.

This required replacing most uses of `std::mutex` with `std::timed_mutex`, and `std::condition_variable` with `std::condition_variable_any`, as appropriate.
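For context, here is a minimal sketch of how the new macro is meant to be used at a call site (the `Example` class and `counter_` member are illustrative only, not part of this PR):

#include <torch/csrc/distributed/c10d/LockGuard.hpp>

class Example {
  // C10D_LOCK_GUARD requires std::timed_mutex so it can call try_lock_for().
  mutable std::timed_mutex mutex_;
  int counter_ = 0;

 public:
  void increment() {
    // Expands to a deferred std::unique_lock<std::timed_mutex> named `lock`,
    // then blocks on the mutex, logging a warning with file/line and the
    // mutex name every 30s until the lock is acquired.
    C10D_LOCK_GUARD(lock, mutex_);
    ++counter_;
  }
};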

Test plan:

existing CI for regressions

will add unit tests for `C10D_LOCK_GUARD` (an initial `LoggingTest.cpp` is part of this change)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134131
Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj
Author: Tristan Rice
Date: 2024-08-28 01:40:42 +00:00
Committed by: PyTorch MergeBot
Parent: c45ca8092d
Commit: f33bcbe5fd
14 changed files with 183 additions and 59 deletions


@@ -495,6 +495,7 @@ libtorch_distributed_base_sources = [
 "torch/csrc/distributed/c10d/Functional.cpp",
 "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
 "torch/csrc/distributed/c10d/GroupRegistry.cpp",
+"torch/csrc/distributed/c10d/LockGuard.cpp",
 "torch/csrc/distributed/c10d/Ops.cpp",
 "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
 "torch/csrc/distributed/c10d/PrefixStore.cpp",


@@ -927,6 +927,7 @@ elseif(USE_CUDA)
 set(CUDA_LINK_LIBRARIES_KEYWORD)
 torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
 target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
+target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only)
 if(USE_CUFILE)
 target_link_libraries(torch_cuda PRIVATE torch::cufile)
@@ -1326,6 +1327,7 @@ if(USE_ROCM)
 ${ROCM_SOURCE_DIR}/rocblas/include
 ${ROCM_SOURCE_DIR}/hipsparse/include
 )
+target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only)
 if(USE_FLASH_ATTENTION)
 target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
 endif()


@@ -10,12 +10,14 @@ function(c10d_add_test test_src)
 add_executable(${test_name} "${test_src}")
 target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${TORCH_SRC_DIR}/csrc/distributed>)
 target_link_libraries(${test_name} ${ARGN})
+target_link_libraries(${test_name} fmt::fmt-header-only)
 if(NOT WIN32)
 target_link_libraries(${test_name} pthread)
 endif()
 add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
 endfunction()
+c10d_add_test(LoggingTest.cpp torch_cpu gtest_main)
 c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
 c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
 c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)


@@ -0,0 +1,54 @@
#include <gtest/gtest.h>
#include <future>
#include <thread>
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
TEST(LockGuard, basic) {
std::timed_mutex mutex;
{
C10D_LOCK_GUARD(lock, mutex);
// already locked
ASSERT_FALSE(mutex.try_lock());
}
ASSERT_TRUE(mutex.try_lock());
mutex.unlock();
}
TEST(LockGuard, logging) {
// set log level to INFO
FLAGS_caffe2_log_level = 0;
std::timed_mutex mutex;
mutex.lock();
auto loggingThread = std::async(std::launch::async, [&]() {
std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock};
::c10d::detail::lockWithLogging(
name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__);
});
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10);
while (true) {
ASSERT_LT(std::chrono::system_clock::now(), deadline);
testing::internal::CaptureStderr();
std::this_thread::sleep_for(std::chrono::milliseconds(20));
std::string output = testing::internal::GetCapturedStderr();
if (output.find("my lock: waiting for lock for 10ms") !=
std::string::npos) {
break;
}
}
mutex.unlock();
loggingThread.get();
}


@@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught
 : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
 hasMonitorThreadCaughtError_(false) {}
-std::mutex& getWatchdogMutex() {
+std::timed_mutex& getWatchdogMutex() {
 return workMetaListMutex_;
 }
@@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
 work = pg.allreduce(tensors_);
 {
 // Now run all reduce with errors.
-std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
+std::lock_guard<std::timed_mutex> lock(pg.getWatchdogMutex());
 LOG(INFO) << "Lock watchdog thread.";
 // Wait long enough before monitor thread throws exceptions.
 std::this_thread::sleep_for(


@@ -0,0 +1,29 @@
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
#include <torch/csrc/distributed/c10d/logging.h>
namespace c10d::detail {
void lockWithLogging(
std::unique_lock<std::timed_mutex>& lock,
std::chrono::milliseconds log_interval,
const char* desc,
const char* file,
int line) {
while (!lock.try_lock_for(log_interval)) {
C10D_WARNING(
"{}:{} {}: waiting for lock for {}ms",
file,
line,
desc,
log_interval.count());
}
}
} // namespace c10d::detail


@@ -0,0 +1,32 @@
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <chrono>
#include <mutex>
#include <c10/macros/Export.h>
namespace c10d::detail {
// lockWithLogging is a wrapper around std::unique_lock<std::timed_mutex>
// that automatically logs if the lock cannot be acquired within a given
// timeout.
TORCH_API void lockWithLogging(
std::unique_lock<std::timed_mutex>& lock,
std::chrono::milliseconds log_interval,
const char* desc,
const char* file,
int line);
} // namespace c10d::detail
// TODO: use std::source_location() when we can use C++20
#define C10D_LOCK_GUARD(name, mutex) \
std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock}; \
::c10d::detail::lockWithLogging( \
name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__)


@@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10;
 namespace c10d {
 ncclComm_t NCCLComm::getNcclComm() {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 if (aborted_) {
 auto commFailureMsg = commFailureReason_ != std::nullopt
 ? c10::str(" Original reason for failure was: ", *commFailureReason_)
@@ -391,7 +391,7 @@ std::optional<size_t> NCCLTraceBuffer::record(
 }
 auto traceback =
 torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
-std::lock_guard<std::mutex> guard(mutex_);
+C10D_LOCK_GUARD(guard, mutex_);
 auto te = Entry{
 id_,
@@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks(
 if (!enabled_) {
 return;
 }
-std::lock_guard<std::mutex> guard(mutex_);
+C10D_LOCK_GUARD(guard, mutex_);
 pg_name_to_ranks_[pg_name] = ranks;
 }
@@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) {
 }
 std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
-std::lock_guard<std::mutex> guard(mutex_);
+C10D_LOCK_GUARD(guard, mutex_);
 std::vector<Entry> result;
 result.reserve(entries_.size());
 result.insert(result.end(), entries_.begin() + next_, entries_.end());
@@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id(
 Event* endEvent = nullptr;
 std::optional<float> duration = std::nullopt;
-std::unique_lock<std::mutex> guard(mutex_);
+C10D_LOCK_GUARD(guard, mutex_);
 Entry* entry = &entries_.at(*id % max_entries_);
 if (entry->id_ == *id) {


@@ -14,6 +14,7 @@
 #include <ATen/cuda/CUDAEvent.h>
 #include <c10/util/Exception.h>
 #include <nccl.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
 #include <optional>
@@ -271,7 +272,7 @@ class NCCLComm {
 ~NCCLComm() noexcept {
 // Add lock in this destructor, as aborted_ needs to be read after memory
 // barrier here.
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 if (ncclComm_ && initialized_ && !aborted_) {
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 // Use ncclCommAbort instead of ncclCommDestroy here since
@@ -363,7 +364,7 @@ class NCCLComm {
 NCCLComm(NCCLComm&& other) {
 // Using other's lock, as it reads other's states
 // Can not use this.mutex_, as this object is being constructed.
-std::unique_lock<std::mutex> lock(other.mutex_);
+C10D_LOCK_GUARD(lock, other.mutex_);
 std::swap(ncclComm_, other.ncclComm_);
 std::swap(aborted_, other.aborted_);
 std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
@@ -373,13 +374,13 @@ class NCCLComm {
 ncclComm_t getNcclComm();
 std::optional<std::string> getNcclCommFailureReason() const {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 return commFailureReason_;
 }
 void ncclCommAbort(
 std::optional<std::string> commFailureReason = std::nullopt) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 if (aborted_ && !initialized_) {
 // Should not abort twice.
@@ -427,7 +428,7 @@ class NCCLComm {
 }
 bool isAborted() const {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 return aborted_;
 }
@@ -436,7 +437,7 @@ class NCCLComm {
 }
 ncclResult_t checkForNcclError() {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 if (ncclAsyncErr_ != ncclSuccess) {
 return ncclAsyncErr_;
@@ -451,7 +452,7 @@ class NCCLComm {
 }
 ncclResult_t registerSegment(void* ptr, size_t size) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
 // We register only segments from cache allocator
 // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
@@ -482,7 +483,7 @@ class NCCLComm {
 }
 ncclResult_t deregisterSegment(void* ptr) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
 TORCH_CHECK(
 registeredSegmentHandles_.count(ptr) == 1,
@@ -519,7 +520,7 @@ class NCCLComm {
 bool aborted_;
 uint64_t ncclCommSplitCounter_{0};
 ncclResult_t ncclAsyncErr_;
-mutable std::mutex mutex_;
+mutable std::timed_mutex mutex_;
 // Rank that this communicator corresponds to.
 int rank_;
 // Optional reason for communicator failure, provided by ProcessGroupNCCL for
@@ -638,7 +639,7 @@ struct NCCLTraceBuffer {
 bool enabled_ = false;
 bool capture_cpp_stack_ = false;
-std::mutex mutex_;
+std::timed_mutex mutex_;
 std::vector<Entry> entries_;
 size_t max_entries_ = 0;
 size_t next_ = 0;


@@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
 }
 int ProcessGroupGloo::RecvWork::sourceRank() const {
-std::lock_guard<std::mutex> lock(mutex_);
+std::lock_guard<std::timed_mutex> lock(mutex_);
 return srcRank_;
 }


@@ -23,6 +23,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/thread_name.h>
 #include <torch/csrc/cuda/nccl.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
@@ -301,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 // hooks are called outside the scope of any PG, thus we need traverse
 // communicators in all PGs.
 static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
-static std::mutex ncclCommDevIdxMapMutex;
+static std::timed_mutex ncclCommDevIdxMapMutex;
 static bool allocatorHooksAttached = false;
 std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
@@ -314,7 +315,7 @@ void cacheAllocatorRegisterHook(
 return;
 }
-std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
 for (auto& it : ncclCommDevIdxMap) {
 auto& ncclComm = it.first;
 auto& devIdx = it.second;
@@ -332,7 +333,7 @@ void cacheAllocatorDeregisterHook(
 return;
 }
-std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
 for (auto& it : ncclCommDevIdxMap) {
 auto& ncclComm = it.first;
 auto& devIdx = it.second;
@@ -551,7 +552,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
 }
 auto exception_ptr = checkForNCCLErrors();
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 exception_ = exception_ptr;
 if (exception_) {
 LOG(ERROR) << logPrefix() << "Collective " << *this
@@ -567,7 +568,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
 void ProcessGroupNCCL::WorkNCCL::setException(
 std::exception_ptr exception_ptr) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 exception_ = exception_ptr;
 }
@@ -776,12 +777,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
 std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
 bool timing) {
 auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
-std::lock_guard<std::mutex> lock(this->cacheMutex_);
+C10D_LOCK_GUARD(lock, this->cacheMutex_);
 this->eventsArray_[timing ? 1 : 0].push_back(event);
 };
 at::cuda::CUDAEvent* event = nullptr;
 {
-std::lock_guard<std::mutex> lock(cacheMutex_);
+C10D_LOCK_GUARD(lock, cacheMutex_);
 auto events = eventsArray_[timing ? 1 : 0];
 if (!events.empty()) {
 event = events.back();
@@ -1086,8 +1087,9 @@ void ProcessGroupNCCL::waitForPendingWorks() {
 while (true) {
 {
 std::lock(workMetaListMutex_, completedWorkListMutex_);
-std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
-std::lock_guard<std::mutex> lockHook(
+std::lock_guard<std::timed_mutex> lockWork(
+workMetaListMutex_, std::adopt_lock);
+std::lock_guard<std::timed_mutex> lockHook(
 completedWorkListMutex_, std::adopt_lock);
 if (workMetaList_.empty() && completedWorkList_.empty()) {
@@ -1207,7 +1209,7 @@ bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
 }
 ncclCommDevIdxMapMutex.unlock();
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 abortCommsFromMap(devNCCLCommMap_, abortReason);
 abortCommsFromMap(inInitializationCommMap_, abortReason);
 return true;
@@ -1276,8 +1278,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
 // Serialize all calls to this function to avoid corrupting data, but allow
 // multiple calls in one runtime. User is responsible for preserving the
 // output file from an earlier call before a later call overwrites it.
-static std::mutex writeDebugInfoMutex;
-std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
+static std::timed_mutex writeDebugInfoMutex;
+C10D_LOCK_GUARD(lock, writeDebugInfoMutex);
 LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
 if (ncclTraceBufferSize_ > 0) {
 // We dump nccl trace into local disk by default and users can register
@@ -1356,7 +1358,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
 // This won't have any lock since this lock is only used here.
 // Please be aware that mutex `monitorMutex_` should not be used
 // somewhere else to avoid the deadlock.
-std::unique_lock<std::mutex> lock(monitorMutex_);
+C10D_LOCK_GUARD(lock, monitorMutex_);
 if (monitorWakeUpCV_.wait_for(
 lock, std::chrono::milliseconds(monitorPollInterval), [&] {
 return terminateHeartbeatMonitorThread_.load();
@@ -1681,7 +1683,7 @@ const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
 void ProcessGroupNCCL::addEphemeralTimeout(
 const std::chrono::milliseconds& timeout) {
-std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
 ephemeralTimeoutActive_ += timeout;
 }
@@ -1704,7 +1706,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
 while (!done || !terminateProcessGroup_.load()) {
-std::unique_lock<std::mutex> lock(workMetaListMutex_);
+C10D_LOCK_GUARD(lock, workMetaListMutex_);
 // We busy-poll the work vector every kWatchdogThreadSleepMillis
 // milliseconds as long as the atomic is True.
 workMetaListCV_.wait_for(
@@ -1876,7 +1878,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 if (work.isCompleted()) {
 {
 // Reset the timeout and first work if the work is completed.
-std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
 if (work.ownedEphermeralTimeout_.count() > 0) {
 ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
 ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
@@ -1891,7 +1893,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 // Move Work object to completedWorkList_ to be consumed by the hook
 // thread
 {
-const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
+C10D_LOCK_GUARD(lock, completedWorkListMutex_);
 completedWorkList_.splice(
 completedWorkList_.end(), workMetaList_, it++);
 }
@@ -1919,7 +1921,7 @@ void ProcessGroupNCCL::runHookLoop() {
 bool done = false;
 while (!done || !terminateProcessGroup_.load()) {
-std::unique_lock<std::mutex> lock(completedWorkListMutex_);
+C10D_LOCK_GUARD(lock, completedWorkListMutex_);
 // We busy-poll the work vector every kWatchdogThreadSleepMillis
 // milliseconds as long as the atomic is True.
 completedWorkListCV_.wait_for(
@@ -2092,7 +2094,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
 }
 void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
 TORCH_INTERNAL_ASSERT(
 false,
@@ -2140,7 +2142,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 usedDeviceIdxs_.insert(device.index());
 {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
 // Reuse the cached communicator if there is one.
 return devNCCLCommMap_[deviceKey];
@@ -2214,7 +2216,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 options_->split_color != 0,
 "Must specify a non-zero color when splitting");
 // Find a valid, healthy communicator to split from if possible.
-std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
+C10D_LOCK_GUARD(lock, options_->split_from->mutex_);
 auto& other_comms = options_->split_from->devNCCLCommMap_;
 auto dit = other_comms.find(getKeyFromDevice(device));
 if (dit != other_comms.end()) {
@@ -2268,7 +2270,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 options_->is_high_priority_stream || force_high);
 {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 inInitializationCommMap_.emplace(deviceKey, ncclComm);
 }
@@ -2518,7 +2520,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
 const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
 const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
 std::chrono::milliseconds timeout = option->timeout;
-std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
 if (ephemeralTimeoutActive_.count() > 0) {
 timeout += ephemeralTimeoutActive_;
 }
@@ -2531,7 +2533,7 @@
 void ProcessGroupNCCL::workEnqueue(
 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
 if (!terminateProcessGroup_.load()) {
-std::lock_guard<std::mutex> lock(workMetaListMutex_);
+C10D_LOCK_GUARD(lock, workMetaListMutex_);
 // Avoid view tensors to be processed in cleanup thread.
 // View tensors' destruction invokes autograd_meta, which
 // needs to be destructed in user thread. Otherwise will


@@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 static CUDAEventCache& get();
 private:
-std::mutex cacheMutex_;
+std::timed_mutex cacheMutex_;
 // NOTE: We intentionaly store raw pointers so that
 // we do not attempt to destroy the event objects on process exit,
 // because cuda may be gone.
@@ -920,7 +920,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 // ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
 // TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
 // And consolidate them if possible.
-std::mutex mtxTimeoutExtension_;
+std::timed_mutex mtxTimeoutExtension_;
 // The ephemeral timeout added on top of existing timeout for works issued
 // before first work finishes.
@@ -980,7 +980,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 inInitializationCommMap_;
 // Mutex to guard maps like devNCCLCommMap_.
-std::mutex mutex_;
+std::timed_mutex mutex_;
 // Heartbeat of watchdog thread.
 std::atomic_uint64_t heartbeat_;
@@ -1041,18 +1041,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 static std::atomic<bool> shouldDump_;
 // Mutex to Guard workMetaList_
-std::mutex workMetaListMutex_;
+std::timed_mutex workMetaListMutex_;
 // Mutex to Guard monitorWakeUpCV_
-std::mutex monitorMutex_;
+std::timed_mutex monitorMutex_;
 bool writeDebugInfo_ = false;
 // Condition Variable for watchdog thread sleep
-std::condition_variable workMetaListCV_;
+std::condition_variable_any workMetaListCV_;
 // Condition Variable for monitor thread to wake up early
-std::condition_variable monitorWakeUpCV_;
+std::condition_variable_any monitorWakeUpCV_;
 // Vector to Store WorkNCCL pointers
 std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
@@ -1060,10 +1060,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
 // Mutex to Guard workMetaList_
-std::mutex completedWorkListMutex_;
+std::timed_mutex completedWorkListMutex_;
 // Condition Variable for watchdog thread sleep
-std::condition_variable completedWorkListCV_;
+std::condition_variable_any completedWorkListCV_;
 std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;


@@ -1,5 +1,6 @@
 #include <ATen/ThreadLocalState.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/Work.hpp>
 #include <utility>
@@ -45,17 +46,17 @@ OpType Work::retrieveOpType() const {
 Work::~Work() = default;
 bool Work::isCompleted() {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 return completed_;
 }
 bool Work::isSuccess() const {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 return !exception_;
 }
 std::exception_ptr Work::exception() const {
-std::lock_guard<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 return exception_;
 }
@@ -73,7 +74,7 @@ std::vector<at::Tensor> Work::result() {
 void Work::synchronize() {}
 bool Work::wait(std::chrono::milliseconds timeout) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 if (timeout == kNoTimeout) {
 // This waits without a timeout.
 cv_.wait(lock, [&] { return completed_; });
@@ -103,7 +104,7 @@ c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
 }
 void Work::finish(std::exception_ptr exception) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 completed_ = true;
 exception_ = std::move(exception);
 if (recordFunctionEndCallback_) {
@@ -115,7 +116,7 @@ void Work::finish(std::exception_ptr exception) {
 }
 void Work::finishAndThrow(std::exception_ptr exception) {
-std::unique_lock<std::mutex> lock(mutex_);
+C10D_LOCK_GUARD(lock, mutex_);
 completed_ = true;
 exception_ = std::move(exception);
 if (recordFunctionEndCallback_) {


@@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder {
 // provided by the user.
 void finishAndThrow(std::exception_ptr exception);
-mutable std::mutex mutex_;
-std::condition_variable cv_;
+mutable std::timed_mutex mutex_;
+std::condition_variable_any cv_;
 bool completed_ = false;
 std::exception_ptr exception_;