Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "c10d/logging: add C10D_LOCK_GUARD (#134131)"
This reverts commit 4c28a0eb0ba437c1b7db559f63f8bec17bd48f69. Reverted https://github.com/pytorch/pytorch/pull/134131 on behalf of https://github.com/ZainRizvi due to Sorry but this causes formatting errors internally which make it fail to build. See D61759282 ([comment](https://github.com/pytorch/pytorch/pull/134131#issuecomment-2310455878))
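For readers skimming the diff below: the reverted change had replaced plain std::mutex locking in c10d with a C10D_LOCK_GUARD macro over std::timed_mutex, which logs a warning while a lock acquisition is stuck instead of blocking silently. A minimal, self-contained sketch of that mechanism, reassembled from the removed lines in this diff (the 30-second interval and the defer_lock/try_lock_for pattern come from the removed macro; the plain std::cerr logging and the main() driver are illustrative stand-ins for C10D_WARNING and the real call sites):

#include <chrono>
#include <iostream>
#include <mutex>

// Sketch of the removed helper: keep retrying the lock and emit a warning each
// time an attempt times out, so long waits become visible in the logs.
inline void lockWithLogging(
    std::unique_lock<std::timed_mutex>& lock,
    std::chrono::milliseconds log_interval,
    const char* desc) {
  while (!lock.try_lock_for(log_interval)) {
    std::cerr << desc << ": waiting for lock for " << log_interval.count()
              << "ms\n";
  }
}

// Sketch of the removed macro: construct the lock deferred, then acquire it
// through the logging helper (the real macro also passed __FILE__/__LINE__).
#define C10D_LOCK_GUARD(name, mutex)                               \
  std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock}; \
  lockWithLogging(name, std::chrono::seconds(30), #mutex)

int main() {
  std::timed_mutex m;
  C10D_LOCK_GUARD(lock, m); // acquires m; would warn every 30s if contended
  return 0;
}

The revert swaps every such guard back to std::lock_guard/std::unique_lock over std::mutex, and drops the timed-mutex members, the helper, and its test.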
@@ -927,7 +927,6 @@ elseif(USE_CUDA)
 set(CUDA_LINK_LIBRARIES_KEYWORD)
 torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
 target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
-target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only)

 if(USE_CUFILE)
 target_link_libraries(torch_cuda PRIVATE torch::cufile)
@@ -1327,7 +1326,6 @@ if(USE_ROCM)
 ${ROCM_SOURCE_DIR}/rocblas/include
 ${ROCM_SOURCE_DIR}/hipsparse/include
 )
-target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only)
 if(USE_FLASH_ATTENTION)
 target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
 endif()
@@ -10,14 +10,12 @@ function(c10d_add_test test_src)
 add_executable(${test_name} "${test_src}")
 target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${TORCH_SRC_DIR}/csrc/distributed>)
 target_link_libraries(${test_name} ${ARGN})
-target_link_libraries(${test_name} fmt::fmt-header-only)
 if(NOT WIN32)
 target_link_libraries(${test_name} pthread)
 endif()
 add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
 endfunction()

-c10d_add_test(LoggingTest.cpp torch_cpu gtest_main)
 c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
 c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
 c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)
@@ -1,50 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <future>
-#include <thread>
-
-#include <torch/csrc/distributed/c10d/logging.h>
-
-TEST(LockGuard, basic) {
-std::timed_mutex mutex;
-
-{
-C10D_LOCK_GUARD(lock, mutex);
-
-// already locked
-ASSERT_FALSE(mutex.try_lock());
-}
-
-ASSERT_TRUE(mutex.try_lock());
-mutex.unlock();
-}
-
-TEST(LockGuard, logging) {
-std::timed_mutex mutex;
-
-mutex.lock();
-
-auto loggingThread = std::async(std::launch::async, [&]() {
-std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock};
-::c10d::detail::lockWithLogging(
-name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__);
-});
-
-auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10);
-while (true) {
-ASSERT_LT(std::chrono::system_clock::now(), deadline);
-
-testing::internal::CaptureStderr();
-std::this_thread::sleep_for(std::chrono::milliseconds(20));
-std::string output = testing::internal::GetCapturedStderr();
-
-if (output.find("my lock: waiting for lock for 10ms") !=
-std::string::npos) {
-break;
-}
-}
-
-mutex.unlock();
-
-loggingThread.get();
-}
@@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught
 : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
 hasMonitorThreadCaughtError_(false) {}

-std::timed_mutex& getWatchdogMutex() {
+std::mutex& getWatchdogMutex() {
 return workMetaListMutex_;
 }

@@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
 work = pg.allreduce(tensors_);
 {
 // Now run all reduce with errors.
-std::lock_guard<std::timed_mutex> lock(pg.getWatchdogMutex());
+std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
 LOG(INFO) << "Lock watchdog thread.";
 // Wait long enough before monitor thread throws exceptions.
 std::this_thread::sleep_for(
@@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10;
 namespace c10d {

 ncclComm_t NCCLComm::getNcclComm() {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 if (aborted_) {
 auto commFailureMsg = commFailureReason_ != std::nullopt
 ? c10::str(" Original reason for failure was: ", *commFailureReason_)
@@ -391,7 +391,7 @@ std::optional<size_t> NCCLTraceBuffer::record(
 }
 auto traceback =
 torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
-C10D_LOCK_GUARD(guard, mutex_);
+std::lock_guard<std::mutex> guard(mutex_);

 auto te = Entry{
 id_,
@@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks(
 if (!enabled_) {
 return;
 }
-C10D_LOCK_GUARD(guard, mutex_);
+std::lock_guard<std::mutex> guard(mutex_);
 pg_name_to_ranks_[pg_name] = ranks;
 }

@@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) {
 }

 std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
-C10D_LOCK_GUARD(guard, mutex_);
+std::lock_guard<std::mutex> guard(mutex_);
 std::vector<Entry> result;
 result.reserve(entries_.size());
 result.insert(result.end(), entries_.begin() + next_, entries_.end());
@@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id(
 Event* endEvent = nullptr;
 std::optional<float> duration = std::nullopt;

-C10D_LOCK_GUARD(guard, mutex_);
+std::unique_lock<std::mutex> guard(mutex_);

 Entry* entry = &entries_.at(*id % max_entries_);
 if (entry->id_ == *id) {
@@ -14,7 +14,6 @@
 #include <c10/util/Exception.h>
 #include <nccl.h>
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
-#include <torch/csrc/distributed/c10d/logging.h>
 #include <optional>

 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -271,7 +270,7 @@ class NCCLComm {
 ~NCCLComm() noexcept {
 // Add lock in this destructor, as aborted_ needs to be read after memory
 // barrier here.
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 if (ncclComm_ && initialized_ && !aborted_) {
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 // Use ncclCommAbort instead of ncclCommDestroy here since
@@ -363,7 +362,7 @@ class NCCLComm {
 NCCLComm(NCCLComm&& other) {
 // Using other's lock, as it reads other's states
 // Can not use this.mutex_, as this object is being constructed.
-C10D_LOCK_GUARD(lock, other.mutex_);
+std::unique_lock<std::mutex> lock(other.mutex_);
 std::swap(ncclComm_, other.ncclComm_);
 std::swap(aborted_, other.aborted_);
 std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
@@ -373,13 +372,13 @@ class NCCLComm {
 ncclComm_t getNcclComm();

 std::optional<std::string> getNcclCommFailureReason() const {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 return commFailureReason_;
 }

 void ncclCommAbort(
 std::optional<std::string> commFailureReason = std::nullopt) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 if (aborted_ && !initialized_) {
 // Should not abort twice.
@@ -427,7 +426,7 @@ class NCCLComm {
 }

 bool isAborted() const {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 return aborted_;
 }

@@ -436,7 +435,7 @@ class NCCLComm {
 }

 ncclResult_t checkForNcclError() {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
 if (ncclAsyncErr_ != ncclSuccess) {
 return ncclAsyncErr_;
@@ -451,7 +450,7 @@ class NCCLComm {
 }

 ncclResult_t registerSegment(void* ptr, size_t size) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
 // We register only segments from cache allocator
 // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
@@ -482,7 +481,7 @@ class NCCLComm {
 }

 ncclResult_t deregisterSegment(void* ptr) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
 TORCH_CHECK(
 registeredSegmentHandles_.count(ptr) == 1,
@@ -519,7 +518,7 @@ class NCCLComm {
 bool aborted_;
 uint64_t ncclCommSplitCounter_{0};
 ncclResult_t ncclAsyncErr_;
-mutable std::timed_mutex mutex_;
+mutable std::mutex mutex_;
 // Rank that this communicator corresponds to.
 int rank_;
 // Optional reason for communicator failure, provided by ProcessGroupNCCL for
@@ -638,7 +637,7 @@ struct NCCLTraceBuffer {

 bool enabled_ = false;
 bool capture_cpp_stack_ = false;
-std::timed_mutex mutex_;
+std::mutex mutex_;
 std::vector<Entry> entries_;
 size_t max_entries_ = 0;
 size_t next_ = 0;
@@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
 }

 int ProcessGroupGloo::RecvWork::sourceRank() const {
-std::lock_guard<std::timed_mutex> lock(mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 return srcRank_;
 }

@@ -30,7 +30,6 @@
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
 #include <torch/csrc/distributed/c10d/logger.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
 #include <torch/torch.h>
 #include <optional>

@@ -302,7 +301,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 // hooks are called outside the scope of any PG, thus we need traverse
 // communicators in all PGs.
 static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
-static std::timed_mutex ncclCommDevIdxMapMutex;
+static std::mutex ncclCommDevIdxMapMutex;
 static bool allocatorHooksAttached = false;

 std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
@@ -315,7 +314,7 @@ void cacheAllocatorRegisterHook(
 return;
 }

-C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
+std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
 for (auto& it : ncclCommDevIdxMap) {
 auto& ncclComm = it.first;
 auto& devIdx = it.second;
@@ -333,7 +332,7 @@ void cacheAllocatorDeregisterHook(
 return;
 }

-C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
+std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
 for (auto& it : ncclCommDevIdxMap) {
 auto& ncclComm = it.first;
 auto& devIdx = it.second;
@@ -552,7 +551,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
 }

 auto exception_ptr = checkForNCCLErrors();
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 exception_ = exception_ptr;
 if (exception_) {
 LOG(ERROR) << logPrefix() << "Collective " << *this
@@ -568,7 +567,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {

 void ProcessGroupNCCL::WorkNCCL::setException(
 std::exception_ptr exception_ptr) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 exception_ = exception_ptr;
 }

@@ -777,12 +776,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
 std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
 bool timing) {
 auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
-C10D_LOCK_GUARD(lock, this->cacheMutex_);
+std::lock_guard<std::mutex> lock(this->cacheMutex_);
 this->eventsArray_[timing ? 1 : 0].push_back(event);
 };
 at::cuda::CUDAEvent* event = nullptr;
 {
-C10D_LOCK_GUARD(lock, cacheMutex_);
+std::lock_guard<std::mutex> lock(cacheMutex_);
 auto events = eventsArray_[timing ? 1 : 0];
 if (!events.empty()) {
 event = events.back();
@@ -1087,9 +1086,8 @@ void ProcessGroupNCCL::waitForPendingWorks() {
 while (true) {
 {
 std::lock(workMetaListMutex_, completedWorkListMutex_);
-std::lock_guard<std::timed_mutex> lockWork(
-workMetaListMutex_, std::adopt_lock);
-std::lock_guard<std::timed_mutex> lockHook(
+std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
+std::lock_guard<std::mutex> lockHook(
 completedWorkListMutex_, std::adopt_lock);

 if (workMetaList_.empty() && completedWorkList_.empty()) {
@@ -1205,7 +1203,7 @@ bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
 }
 ncclCommDevIdxMapMutex.unlock();

-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 abortCommsFromMap(devNCCLCommMap_, abortReason);
 abortCommsFromMap(inInitializationCommMap_, abortReason);
 return true;
@@ -1274,8 +1272,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
 // Serialize all calls to this function to avoid corrupting data, but allow
 // multiple calls in one runtime. User is responsible for preserving the
 // output file from an earlier call before a later call overwrites it.
-static std::timed_mutex writeDebugInfoMutex;
-C10D_LOCK_GUARD(lock, writeDebugInfoMutex);
+static std::mutex writeDebugInfoMutex;
+std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
 LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
 if (ncclTraceBufferSize_ > 0) {
 // We dump nccl trace into local disk by default and users can register
@@ -1354,7 +1352,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
 // This won't have any lock since this lock is only used here.
 // Please be aware that mutex `monitorMutex_` should not be used
 // somewhere else to avoid the deadlock.
-C10D_LOCK_GUARD(lock, monitorMutex_);
+std::unique_lock<std::mutex> lock(monitorMutex_);
 if (monitorWakeUpCV_.wait_for(
 lock, std::chrono::milliseconds(monitorPollInterval), [&] {
 return terminateHeartbeatMonitorThread_.load();
@@ -1679,7 +1677,7 @@ const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {

 void ProcessGroupNCCL::addEphemeralTimeout(
 const std::chrono::milliseconds& timeout) {
-C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
+std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
 ephemeralTimeoutActive_ += timeout;
 }

@@ -1702,7 +1700,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;

 while (!done || !terminateProcessGroup_.load()) {
-C10D_LOCK_GUARD(lock, workMetaListMutex_);
+std::unique_lock<std::mutex> lock(workMetaListMutex_);
 // We busy-poll the work vector every kWatchdogThreadSleepMillis
 // milliseconds as long as the atomic is True.
 workMetaListCV_.wait_for(
@@ -1874,7 +1872,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 if (work.isCompleted()) {
 {
 // Reset the timeout and first work if the work is completed.
-C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
+std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
 if (work.ownedEphermeralTimeout_.count() > 0) {
 ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
 ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
@@ -1889,7 +1887,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 // Move Work object to completedWorkList_ to be consumed by the hook
 // thread
 {
-C10D_LOCK_GUARD(lock, completedWorkListMutex_);
+const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
 completedWorkList_.splice(
 completedWorkList_.end(), workMetaList_, it++);
 }
@@ -1917,7 +1915,7 @@ void ProcessGroupNCCL::runHookLoop() {

 bool done = false;
 while (!done || !terminateProcessGroup_.load()) {
-C10D_LOCK_GUARD(lock, completedWorkListMutex_);
+std::unique_lock<std::mutex> lock(completedWorkListMutex_);
 // We busy-poll the work vector every kWatchdogThreadSleepMillis
 // milliseconds as long as the atomic is True.
 completedWorkListCV_.wait_for(
@@ -2090,7 +2088,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
 }

 void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
 TORCH_INTERNAL_ASSERT(
 false,
@@ -2138,7 +2136,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 usedDeviceIdxs_.insert(device.index());

 {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
 // Reuse the cached communicator if there is one.
 return devNCCLCommMap_[deviceKey];
@@ -2214,7 +2212,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 options_->split_color != 0,
 "Must specify a non-zero color when splitting");
 // Find a valid, healthy communicator to split from if possible.
-C10D_LOCK_GUARD(lock, options_->split_from->mutex_);
+std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
 auto& other_comms = options_->split_from->devNCCLCommMap_;
 auto dit = other_comms.find(getKeyFromDevice(device));
 if (dit != other_comms.end()) {
@@ -2268,7 +2266,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 options_->is_high_priority_stream || force_high);

 {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 inInitializationCommMap_.emplace(deviceKey, ncclComm);
 }

@@ -2518,7 +2516,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
 const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
 const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
 std::chrono::milliseconds timeout = option->timeout;
-C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
+std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
 if (ephemeralTimeoutActive_.count() > 0) {
 timeout += ephemeralTimeoutActive_;
 }
@@ -2531,7 +2529,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
 void ProcessGroupNCCL::workEnqueue(
 c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
 if (!terminateProcessGroup_.load()) {
-C10D_LOCK_GUARD(lock, workMetaListMutex_);
+std::lock_guard<std::mutex> lock(workMetaListMutex_);
 // Avoid view tensors to be processed in cleanup thread.
 // View tensors' destruction invokes autograd_meta, which
 // needs to be destructed in user thread. Otherwise will
@@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 static CUDAEventCache& get();

 private:
-std::timed_mutex cacheMutex_;
+std::mutex cacheMutex_;
 // NOTE: We intentionaly store raw pointers so that
 // we do not attempt to destroy the event objects on process exit,
 // because cuda may be gone.
@@ -918,7 +918,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 // ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
 // TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
 // And consolidate them if possible.
-std::timed_mutex mtxTimeoutExtension_;
+std::mutex mtxTimeoutExtension_;

 // The ephemeral timeout added on top of existing timeout for works issued
 // before first work finishes.
@@ -978,7 +978,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 inInitializationCommMap_;

 // Mutex to guard maps like devNCCLCommMap_.
-std::timed_mutex mutex_;
+std::mutex mutex_;

 // Heartbeat of watchdog thread.
 std::atomic_uint64_t heartbeat_;
@@ -1039,18 +1039,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 static std::atomic<bool> shouldDump_;

 // Mutex to Guard workMetaList_
-std::timed_mutex workMetaListMutex_;
+std::mutex workMetaListMutex_;

 // Mutex to Guard monitorWakeUpCV_
-std::timed_mutex monitorMutex_;
+std::mutex monitorMutex_;

 bool writeDebugInfo_ = false;

 // Condition Variable for watchdog thread sleep
-std::condition_variable_any workMetaListCV_;
+std::condition_variable workMetaListCV_;

 // Condition Variable for monitor thread to wake up early
-std::condition_variable_any monitorWakeUpCV_;
+std::condition_variable monitorWakeUpCV_;

 // Vector to Store WorkNCCL pointers
 std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
@@ -1058,10 +1058,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;

 // Mutex to Guard workMetaList_
-std::timed_mutex completedWorkListMutex_;
+std::mutex completedWorkListMutex_;

 // Condition Variable for watchdog thread sleep
-std::condition_variable_any completedWorkListCV_;
+std::condition_variable completedWorkListCV_;

 std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;

@@ -1,7 +1,6 @@
 #include <ATen/ThreadLocalState.h>

 #include <torch/csrc/distributed/c10d/Work.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
 #include <utility>

 namespace c10d {
@@ -46,17 +45,17 @@ OpType Work::retrieveOpType() const {
 Work::~Work() = default;

 bool Work::isCompleted() {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 return completed_;
 }

 bool Work::isSuccess() const {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 return !exception_;
 }

 std::exception_ptr Work::exception() const {
-C10D_LOCK_GUARD(lock, mutex_);
+std::lock_guard<std::mutex> lock(mutex_);
 return exception_;
 }

@@ -74,7 +73,7 @@ std::vector<at::Tensor> Work::result() {
 void Work::synchronize() {}

 bool Work::wait(std::chrono::milliseconds timeout) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 if (timeout == kNoTimeout) {
 // This waits without a timeout.
 cv_.wait(lock, [&] { return completed_; });
@@ -104,7 +103,7 @@ c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
 }

 void Work::finish(std::exception_ptr exception) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 completed_ = true;
 exception_ = std::move(exception);
 if (recordFunctionEndCallback_) {
@@ -116,7 +115,7 @@ void Work::finish(std::exception_ptr exception) {
 }

 void Work::finishAndThrow(std::exception_ptr exception) {
-C10D_LOCK_GUARD(lock, mutex_);
+std::unique_lock<std::mutex> lock(mutex_);
 completed_ = true;
 exception_ = std::move(exception);
 if (recordFunctionEndCallback_) {
@@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder {
 // provided by the user.
 void finishAndThrow(std::exception_ptr exception);

-mutable std::timed_mutex mutex_;
-std::condition_variable_any cv_;
+mutable std::mutex mutex_;
+std::condition_variable cv_;
 bool completed_ = false;
 std::exception_ptr exception_;

@@ -34,20 +34,4 @@ bool isLogLevelEnabled(LogLevel level) noexcept {
 return false;
 }

-void lockWithLogging(
-std::unique_lock<std::timed_mutex>& lock,
-std::chrono::milliseconds log_interval,
-c10::string_view desc,
-c10::string_view file,
-int line) {
-while (!lock.try_lock_for(log_interval)) {
-C10D_WARNING(
-"{}:{} {}: waiting for lock for {}ms",
-file,
-line,
-desc,
-log_interval.count());
-}
-}
-
 } // namespace c10d::detail
@@ -6,7 +6,6 @@

 #pragma once

-#include <mutex>
 #include <string>

 #include <c10/macros/Macros.h>
@@ -25,16 +24,6 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) {
 return fmt::vformat(fmt, fmt::make_format_args(args...));
 }

-// logWithLogging is a wrapper around std::unique_lock<std::timed_mutex>
-// that automatically logs if the lock cannot be acquired within a given
-// timeout.
-TORCH_API void lockWithLogging(
-std::unique_lock<std::timed_mutex>& lock,
-std::chrono::milliseconds log_interval,
-c10::string_view desc,
-c10::string_view file,
-int line);
-
 } // namespace detail
 } // namespace c10d

@@ -60,9 +49,3 @@ TORCH_API void lockWithLogging(
 #define C10D_TRACE(...) \
 LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \
 << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__)
-
-// TODO: use std::source_location() when we can use C++20
-#define C10D_LOCK_GUARD(name, mutex) \
-std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock}; \
-::c10d::detail::lockWithLogging( \
-name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__)
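Taken together with the macro removal just above, the practical effect of the revert at each call site is: the critical sections are unchanged, but a contended lock now blocks with no diagnostics. A hypothetical call site for illustration only (workMetaListMutex and scanWorkQueue are made-up names, not the actual PyTorch members):

#include <mutex>

std::mutex workMetaListMutex; // after the revert: a plain mutex again

void scanWorkQueue() {
  // After this revert: blocks silently for however long the lock is held.
  std::lock_guard<std::mutex> lock(workMetaListMutex);
  // ... inspect the work list ...
}

// Before the revert (#134131), the same site used a std::timed_mutex plus
//   C10D_LOCK_GUARD(lock, workMetaListMutex);
// which logged a warning every 30 seconds while the acquisition was stuck.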