Revert "c10d/logging: add C10D_LOCK_GUARD (#134131)"

This reverts commit f33bcbe5fd67e6b18be259ad2f0dc11c74157075.

Reverted https://github.com/pytorch/pytorch/pull/134131 on behalf of https://github.com/kit1980 due to See D61985186 ([comment](https://github.com/pytorch/pytorch/pull/134131#issuecomment-2327556381))
PyTorch MergeBot committed on 2024-09-03 22:35:14 +00:00
Parent: 2fd36086bc
Commit: c044deb9ce
14 changed files with 59 additions and 183 deletions
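For context, the change being reverted had switched several c10d mutexes from std::mutex to std::timed_mutex and wrapped their acquisition in a C10D_LOCK_GUARD macro that logs while waiting (the macro and its helper appear in the LockGuard.hpp and LockGuard.cpp hunks below). A minimal sketch of the pattern this revert restores; the member and function names here are illustrative only, not taken from the diff:

#include <mutex>

// State restored by this revert: plain standard-library locking on std::mutex.
std::mutex exampleMutex_;  // illustrative member name
void exampleCriticalSection() {
  std::lock_guard<std::mutex> lock(exampleMutex_);
  // ... critical section ...
}

// The reverted change had instead used a std::timed_mutex together with
//   C10D_LOCK_GUARD(lock, exampleMutex_);
// which expanded to a std::unique_lock<std::timed_mutex> that logs a warning
// every 30 seconds while waiting (see the LockGuard.hpp hunk below).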


@@ -496,7 +496,6 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Functional.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
"torch/csrc/distributed/c10d/GroupRegistry.cpp",
"torch/csrc/distributed/c10d/LockGuard.cpp",
"torch/csrc/distributed/c10d/Ops.cpp",
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
"torch/csrc/distributed/c10d/PrefixStore.cpp",


@@ -927,7 +927,6 @@ elseif(USE_CUDA)
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only)
if(USE_CUFILE)
target_link_libraries(torch_cuda PRIVATE torch::cufile)
@@ -1327,7 +1326,6 @@ if(USE_ROCM)
${ROCM_SOURCE_DIR}/rocblas/include
${ROCM_SOURCE_DIR}/hipsparse/include
)
target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
endif()


@@ -10,14 +10,12 @@ function(c10d_add_test test_src)
add_executable(${test_name} "${test_src}")
target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${TORCH_SRC_DIR}/csrc/distributed>)
target_link_libraries(${test_name} ${ARGN})
target_link_libraries(${test_name} fmt::fmt-header-only)
if(NOT WIN32)
target_link_libraries(${test_name} pthread)
endif()
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
endfunction()
c10d_add_test(LoggingTest.cpp torch_cpu gtest_main)
c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)


@@ -1,54 +0,0 @@
#include <gtest/gtest.h>
#include <future>
#include <thread>
#include <c10/util/Logging.h>
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
TEST(LockGuard, basic) {
std::timed_mutex mutex;
{
C10D_LOCK_GUARD(lock, mutex);
// already locked
ASSERT_FALSE(mutex.try_lock());
}
ASSERT_TRUE(mutex.try_lock());
mutex.unlock();
}
TEST(LockGuard, logging) {
// set log level to INFO
FLAGS_caffe2_log_level = 0;
std::timed_mutex mutex;
mutex.lock();
auto loggingThread = std::async(std::launch::async, [&]() {
std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock};
::c10d::detail::lockWithLogging(
name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__);
});
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10);
while (true) {
ASSERT_LT(std::chrono::system_clock::now(), deadline);
testing::internal::CaptureStderr();
std::this_thread::sleep_for(std::chrono::milliseconds(20));
std::string output = testing::internal::GetCapturedStderr();
if (output.find("my lock: waiting for lock for 10ms") !=
std::string::npos) {
break;
}
}
mutex.unlock();
loggingThread.get();
}


@@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught
: ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
hasMonitorThreadCaughtError_(false) {}
std::timed_mutex& getWatchdogMutex() {
std::mutex& getWatchdogMutex() {
return workMetaListMutex_;
}
@@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
work = pg.allreduce(tensors_);
{
// Now run all reduce with errors.
std::lock_guard<std::timed_mutex> lock(pg.getWatchdogMutex());
std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
LOG(INFO) << "Lock watchdog thread.";
// Wait long enough before monitor thread throws exceptions.
std::this_thread::sleep_for(


@@ -1,29 +0,0 @@
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
#include <torch/csrc/distributed/c10d/logging.h>
namespace c10d::detail {
void lockWithLogging(
std::unique_lock<std::timed_mutex>& lock,
std::chrono::milliseconds log_interval,
const char* desc,
const char* file,
int line) {
while (!lock.try_lock_for(log_interval)) {
C10D_WARNING(
"{}:{} {}: waiting for lock for {}ms",
file,
line,
desc,
log_interval.count());
}
}
} // namespace c10d::detail


@@ -1,32 +0,0 @@
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <chrono>
#include <mutex>
#include <c10/macros/Export.h>
namespace c10d::detail {
// lockWithLogging is a wrapper around std::unique_lock<std::timed_mutex>
// that automatically logs if the lock cannot be acquired within a given
// timeout.
TORCH_API void lockWithLogging(
std::unique_lock<std::timed_mutex>& lock,
std::chrono::milliseconds log_interval,
const char* desc,
const char* file,
int line);
} // namespace c10d::detail
// TODO: use std::source_location() when we can use C++20
#define C10D_LOCK_GUARD(name, mutex) \
std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock}; \
::c10d::detail::lockWithLogging( \
name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__)
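For clarity, this is roughly what the macro above produces at a hypothetical call site; lock and mutex_ are placeholder names, and the expansion follows the #define shown in this hunk:

// C10D_LOCK_GUARD(lock, mutex_); expands to approximately:
std::unique_lock<std::timed_mutex> lock{mutex_, std::defer_lock};
::c10d::detail::lockWithLogging(
    lock, std::chrono::seconds(30), "mutex_", __FILE__, __LINE__);
// lockWithLogging (deleted in the LockGuard.cpp hunk above) loops on
// try_lock_for() and emits a C10D_WARNING every 30 seconds until it
// acquires the lock.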


@@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10;
namespace c10d {
ncclComm_t NCCLComm::getNcclComm() {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
if (aborted_) {
auto commFailureMsg = commFailureReason_ != std::nullopt
? c10::str(" Original reason for failure was: ", *commFailureReason_)
@@ -391,7 +391,7 @@ std::optional<size_t> NCCLTraceBuffer::record(
}
auto traceback =
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
C10D_LOCK_GUARD(guard, mutex_);
std::lock_guard<std::mutex> guard(mutex_);
auto te = Entry{
id_,
@@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks(
if (!enabled_) {
return;
}
C10D_LOCK_GUARD(guard, mutex_);
std::lock_guard<std::mutex> guard(mutex_);
pg_name_to_ranks_[pg_name] = ranks;
}
@@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) {
}
std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
C10D_LOCK_GUARD(guard, mutex_);
std::lock_guard<std::mutex> guard(mutex_);
std::vector<Entry> result;
result.reserve(entries_.size());
result.insert(result.end(), entries_.begin() + next_, entries_.end());
@@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id(
Event* endEvent = nullptr;
std::optional<float> duration = std::nullopt;
C10D_LOCK_GUARD(guard, mutex_);
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {


@@ -13,7 +13,6 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/util/Exception.h>
#include <nccl.h>
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <optional>
@@ -271,7 +270,7 @@ class NCCLComm {
~NCCLComm() noexcept {
// Add lock in this destructor, as aborted_ needs to be read after memory
// barrier here.
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
if (ncclComm_ && initialized_ && !aborted_) {
#ifdef ENABLE_NCCL_ERROR_CHECKING
// Use ncclCommAbort instead of ncclCommDestroy here since
@@ -363,7 +362,7 @@ class NCCLComm {
NCCLComm(NCCLComm&& other) {
// Using other's lock, as it reads other's states
// Can not use this.mutex_, as this object is being constructed.
C10D_LOCK_GUARD(lock, other.mutex_);
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
@@ -373,13 +372,13 @@ class NCCLComm {
ncclComm_t getNcclComm();
std::optional<std::string> getNcclCommFailureReason() const {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
return commFailureReason_;
}
void ncclCommAbort(
std::optional<std::string> commFailureReason = std::nullopt) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_ && !initialized_) {
// Should not abort twice.
@@ -427,7 +426,7 @@ class NCCLComm {
}
bool isAborted() const {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
return aborted_;
}
@@ -436,7 +435,7 @@ class NCCLComm {
}
ncclResult_t checkForNcclError() {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
@@ -451,7 +450,7 @@ class NCCLComm {
}
ncclResult_t registerSegment(void* ptr, size_t size) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
// We register only segments from cache allocator
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
@@ -482,7 +481,7 @@ class NCCLComm {
}
ncclResult_t deregisterSegment(void* ptr) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 1,
@@ -519,7 +518,7 @@ class NCCLComm {
bool aborted_;
uint64_t ncclCommSplitCounter_{0};
ncclResult_t ncclAsyncErr_;
mutable std::timed_mutex mutex_;
mutable std::mutex mutex_;
// Rank that this communicator corresponds to.
int rank_;
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
@@ -638,7 +637,7 @@ struct NCCLTraceBuffer {
bool enabled_ = false;
bool capture_cpp_stack_ = false;
std::timed_mutex mutex_;
std::mutex mutex_;
std::vector<Entry> entries_;
size_t max_entries_ = 0;
size_t next_ = 0;


@@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
}
int ProcessGroupGloo::RecvWork::sourceRank() const {
std::lock_guard<std::timed_mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
return srcRank_;
}


@@ -23,7 +23,6 @@
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/NanCheck.hpp>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
@@ -303,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
// hooks are called outside the scope of any PG, thus we need traverse
// communicators in all PGs.
static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
static std::timed_mutex ncclCommDevIdxMapMutex;
static std::mutex ncclCommDevIdxMapMutex;
static bool allocatorHooksAttached = false;
std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
@@ -316,7 +315,7 @@ void cacheAllocatorRegisterHook(
return;
}
C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& it : ncclCommDevIdxMap) {
auto& ncclComm = it.first;
auto& devIdx = it.second;
@@ -334,7 +333,7 @@ void cacheAllocatorDeregisterHook(
return;
}
C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& it : ncclCommDevIdxMap) {
auto& ncclComm = it.first;
auto& devIdx = it.second;
@@ -537,7 +536,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
}
auto exception_ptr = checkForNCCLErrors();
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
exception_ = exception_ptr;
if (exception_) {
LOG(ERROR) << logPrefix() << "Collective " << *this
@@ -553,7 +552,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
void ProcessGroupNCCL::WorkNCCL::setException(
std::exception_ptr exception_ptr) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
exception_ = exception_ptr;
}
@@ -762,12 +761,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
bool timing) {
auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
C10D_LOCK_GUARD(lock, this->cacheMutex_);
std::lock_guard<std::mutex> lock(this->cacheMutex_);
this->eventsArray_[timing ? 1 : 0].push_back(event);
};
at::cuda::CUDAEvent* event = nullptr;
{
C10D_LOCK_GUARD(lock, cacheMutex_);
std::lock_guard<std::mutex> lock(cacheMutex_);
auto events = eventsArray_[timing ? 1 : 0];
if (!events.empty()) {
event = events.back();
@@ -1072,9 +1071,8 @@ void ProcessGroupNCCL::waitForPendingWorks() {
while (true) {
{
std::lock(workMetaListMutex_, completedWorkListMutex_);
std::lock_guard<std::timed_mutex> lockWork(
workMetaListMutex_, std::adopt_lock);
std::lock_guard<std::timed_mutex> lockHook(
std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
std::lock_guard<std::mutex> lockHook(
completedWorkListMutex_, std::adopt_lock);
if (workMetaList_.empty() && completedWorkList_.empty()) {
@@ -1187,7 +1185,7 @@ bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
}
ncclCommDevIdxMapMutex.unlock();
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
abortCommsFromMap(devNCCLCommMap_, abortReason);
abortCommsFromMap(inInitializationCommMap_, abortReason);
return true;
@@ -1256,8 +1254,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
// Serialize all calls to this function to avoid corrupting data, but allow
// multiple calls in one runtime. User is responsible for preserving the
// output file from an earlier call before a later call overwrites it.
static std::timed_mutex writeDebugInfoMutex;
C10D_LOCK_GUARD(lock, writeDebugInfoMutex);
static std::mutex writeDebugInfoMutex;
std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
if (ncclTraceBufferSize_ > 0) {
// We dump nccl trace into local disk by default and users can register
@@ -1336,7 +1334,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// This won't have any lock since this lock is only used here.
// Please be aware that mutex `monitorMutex_` should not be used
// somewhere else to avoid the deadlock.
C10D_LOCK_GUARD(lock, monitorMutex_);
std::unique_lock<std::mutex> lock(monitorMutex_);
if (monitorWakeUpCV_.wait_for(
lock, std::chrono::milliseconds(monitorPollInterval), [&] {
return terminateHeartbeatMonitorThread_.load();
@@ -1661,7 +1659,7 @@ const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
void ProcessGroupNCCL::addEphemeralTimeout(
const std::chrono::milliseconds& timeout) {
C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
ephemeralTimeoutActive_ += timeout;
}
@@ -1684,7 +1682,7 @@ void ProcessGroupNCCL::watchdogHandler() {
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
while (!done || !terminateProcessGroup_.load()) {
C10D_LOCK_GUARD(lock, workMetaListMutex_);
std::unique_lock<std::mutex> lock(workMetaListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
// milliseconds as long as the atomic is True.
workMetaListCV_.wait_for(
@@ -1856,7 +1854,7 @@ void ProcessGroupNCCL::watchdogHandler() {
if (work.isCompleted()) {
{
// Reset the timeout and first work if the work is completed.
C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
if (work.ownedEphermeralTimeout_.count() > 0) {
ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
@@ -1871,7 +1869,7 @@ void ProcessGroupNCCL::watchdogHandler() {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
{
C10D_LOCK_GUARD(lock, completedWorkListMutex_);
const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
completedWorkList_.splice(
completedWorkList_.end(), workMetaList_, it++);
}
@@ -1899,7 +1897,7 @@ void ProcessGroupNCCL::runHookLoop() {
bool done = false;
while (!done || !terminateProcessGroup_.load()) {
C10D_LOCK_GUARD(lock, completedWorkListMutex_);
std::unique_lock<std::mutex> lock(completedWorkListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
// milliseconds as long as the atomic is True.
completedWorkListCV_.wait_for(
@@ -2072,7 +2070,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
}
void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
TORCH_INTERNAL_ASSERT(
false,
@@ -2120,7 +2118,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
usedDeviceIdxs_.insert(device.index());
{
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return devNCCLCommMap_[deviceKey];
@@ -2194,7 +2192,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
options_->split_color != 0,
"Must specify a non-zero color when splitting");
// Find a valid, healthy communicator to split from if possible.
C10D_LOCK_GUARD(lock, options_->split_from->mutex_);
std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
auto& other_comms = options_->split_from->devNCCLCommMap_;
auto dit = other_comms.find(getKeyFromDevice(device));
if (dit != other_comms.end()) {
@@ -2248,7 +2246,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
options_->is_high_priority_stream || force_high);
{
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
inInitializationCommMap_.emplace(deviceKey, ncclComm);
}
@@ -2498,7 +2496,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
std::chrono::milliseconds timeout = option->timeout;
C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
if (ephemeralTimeoutActive_.count() > 0) {
timeout += ephemeralTimeoutActive_;
}
@ -2511,7 +2509,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
void ProcessGroupNCCL::workEnqueue(
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
if (!terminateProcessGroup_.load()) {
C10D_LOCK_GUARD(lock, workMetaListMutex_);
std::lock_guard<std::mutex> lock(workMetaListMutex_);
// Avoid view tensors to be processed in cleanup thread.
// View tensors' destruction invokes autograd_meta, which
// needs to be destructed in user thread. Otherwise will


@@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
static CUDAEventCache& get();
private:
std::timed_mutex cacheMutex_;
std::mutex cacheMutex_;
// NOTE: We intentionally store raw pointers so that
// we do not attempt to destroy the event objects on process exit,
// because cuda may be gone.
@@ -920,7 +920,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
// TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
// And consolidate them if possible.
std::timed_mutex mtxTimeoutExtension_;
std::mutex mtxTimeoutExtension_;
// The ephemeral timeout added on top of existing timeout for works issued
// before first work finishes.
@@ -980,7 +980,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
inInitializationCommMap_;
// Mutex to guard maps like devNCCLCommMap_.
std::timed_mutex mutex_;
std::mutex mutex_;
// Heartbeat of watchdog thread.
std::atomic_uint64_t heartbeat_;
@@ -1041,18 +1041,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
static std::atomic<bool> shouldDump_;
// Mutex to Guard workMetaList_
std::timed_mutex workMetaListMutex_;
std::mutex workMetaListMutex_;
// Mutex to Guard monitorWakeUpCV_
std::timed_mutex monitorMutex_;
std::mutex monitorMutex_;
bool writeDebugInfo_ = false;
// Condition Variable for watchdog thread sleep
std::condition_variable_any workMetaListCV_;
std::condition_variable workMetaListCV_;
// Condition Variable for monitor thread to wake up early
std::condition_variable_any monitorWakeUpCV_;
std::condition_variable monitorWakeUpCV_;
// Vector to Store WorkNCCL pointers
std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
@@ -1060,10 +1060,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
// Mutex to Guard workMetaList_
std::timed_mutex completedWorkListMutex_;
std::mutex completedWorkListMutex_;
// Condition Variable for watchdog thread sleep
std::condition_variable_any completedWorkListCV_;
std::condition_variable completedWorkListCV_;
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;


@@ -1,6 +1,5 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/LockGuard.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>
@@ -46,17 +45,17 @@ OpType Work::retrieveOpType() const {
Work::~Work() = default;
bool Work::isCompleted() {
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
return completed_;
}
bool Work::isSuccess() const {
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
return !exception_;
}
std::exception_ptr Work::exception() const {
C10D_LOCK_GUARD(lock, mutex_);
std::lock_guard<std::mutex> lock(mutex_);
return exception_;
}
@@ -74,7 +73,7 @@ std::vector<at::Tensor> Work::result() {
void Work::synchronize() {}
bool Work::wait(std::chrono::milliseconds timeout) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
if (timeout == kNoTimeout) {
// This waits without a timeout.
cv_.wait(lock, [&] { return completed_; });
@@ -104,7 +103,7 @@ c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
}
void Work::finish(std::exception_ptr exception) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
completed_ = true;
exception_ = std::move(exception);
if (recordFunctionEndCallback_) {
@@ -116,7 +115,7 @@ void Work::finish(std::exception_ptr exception) {
}
void Work::finishAndThrow(std::exception_ptr exception) {
C10D_LOCK_GUARD(lock, mutex_);
std::unique_lock<std::mutex> lock(mutex_);
completed_ = true;
exception_ = std::move(exception);
if (recordFunctionEndCallback_) {


@@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder {
// provided by the user.
void finishAndThrow(std::exception_ptr exception);
mutable std::timed_mutex mutex_;
std::condition_variable_any cv_;
mutable std::mutex mutex_;
std::condition_variable cv_;
bool completed_ = false;
std::exception_ptr exception_;
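A note on the cv_ change in this last hunk: std::condition_variable only works with std::unique_lock<std::mutex>, whereas std::condition_variable_any accepts any lockable type such as std::timed_mutex, so reverting mutex_ back to std::mutex also means reverting cv_ back to std::condition_variable. A minimal sketch of the restored pattern, with illustrative names not taken from the diff:

#include <condition_variable>
#include <mutex>

std::mutex doneMutex;            // plain mutex, as restored by this revert
std::condition_variable doneCv;  // pairs with std::unique_lock<std::mutex>
bool done = false;

void waitUntilDone() {
  std::unique_lock<std::mutex> lock(doneMutex);
  doneCv.wait(lock, [] { return done; });
}

// With std::timed_mutex (the reverted configuration), std::condition_variable_any
// would be required instead, since it can wait on any lock type.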