Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-02 06:24:59 +08:00
c10d/logging: add C10D_LOCK_GUARD (#134131)
This adds logs if we can't acquire locks in NCCLUtils and ProcessGroupNCCL for 30s. This is motivated by some deadlocks we're seeing, and it's unclear whether they are in NCCL or on the PyTorch side of things. This required replacing most `std::mutex` with `std::timed_mutex` and `std::condition_variable_any` as appropriate.

Test plan: existing CI for regressions; unit tests on `C10D_LOCK_GUARD` will be added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134131
Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: c45ca8092d
Commit: f33bcbe5fd
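Before the per-file diff, here is the idea in isolation: instead of blocking indefinitely on `std::mutex::lock()`, the guard acquires a `std::timed_mutex` in bounded slices and emits a warning each time a slice expires, so a stuck lock produces a periodic, greppable signal instead of a silent deadlock. A minimal self-contained sketch (with `std::cerr` standing in for c10d's `C10D_WARNING` logging; all names here are illustrative, not the real API, which is added in LockGuard.{hpp,cpp} below):

#include <chrono>
#include <iostream>
#include <mutex>

// Sketch only: retry in log_interval slices; each failed slice prints a
// warning, then we go back to waiting.
void lock_with_logging_sketch(
    std::unique_lock<std::timed_mutex>& lock,
    std::chrono::milliseconds log_interval,
    const char* desc) {
  while (!lock.try_lock_for(log_interval)) {
    std::cerr << desc << ": waiting for lock for " << log_interval.count()
              << "ms\n";
  }
}

int main() {
  std::timed_mutex m;
  std::unique_lock<std::timed_mutex> lock{m, std::defer_lock};
  lock_with_logging_sketch(lock, std::chrono::milliseconds(10), "demo lock");
  // The lock is held here; the guard releases it on scope exit as usual.
}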
build_variables.bzl
@@ -495,6 +495,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Functional.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
     "torch/csrc/distributed/c10d/GroupRegistry.cpp",
+    "torch/csrc/distributed/c10d/LockGuard.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
     "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
     "torch/csrc/distributed/c10d/PrefixStore.cpp",
caffe2/CMakeLists.txt
@@ -927,6 +927,7 @@ elseif(USE_CUDA)
   set(CUDA_LINK_LIBRARIES_KEYWORD)
   torch_compile_options(torch_cuda) # see cmake/public/utils.cmake
   target_compile_definitions(torch_cuda PRIVATE USE_CUDA)
+  target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only)

   if(USE_CUFILE)
     target_link_libraries(torch_cuda PRIVATE torch::cufile)
@@ -1326,6 +1327,7 @@ if(USE_ROCM)
     ${ROCM_SOURCE_DIR}/rocblas/include
     ${ROCM_SOURCE_DIR}/hipsparse/include
   )
+  target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only)
   if(USE_FLASH_ATTENTION)
     target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
   endif()
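(The `fmt::fmt-header-only` link additions above, and the matching one in the c10d test CMakeLists below, are presumably needed because the new `C10D_WARNING` call in LockGuard.cpp formats its message with fmt-style `{}` placeholders, so every target compiling that code needs fmt on its link line.)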
test/cpp/c10d/CMakeLists.txt
@@ -10,12 +10,14 @@ function(c10d_add_test test_src)
   add_executable(${test_name} "${test_src}")
   target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${TORCH_SRC_DIR}/csrc/distributed>)
   target_link_libraries(${test_name} ${ARGN})
+  target_link_libraries(${test_name} fmt::fmt-header-only)
   if(NOT WIN32)
     target_link_libraries(${test_name} pthread)
   endif()
   add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
 endfunction()

+c10d_add_test(LoggingTest.cpp torch_cpu gtest_main)
 c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
 c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
 c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)
test/cpp/c10d/LoggingTest.cpp (new file, 54 lines)
@@ -0,0 +1,54 @@
+#include <gtest/gtest.h>
+
+#include <future>
+#include <thread>
+
+#include <c10/util/Logging.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
+
+TEST(LockGuard, basic) {
+  std::timed_mutex mutex;
+
+  {
+    C10D_LOCK_GUARD(lock, mutex);
+
+    // already locked
+    ASSERT_FALSE(mutex.try_lock());
+  }
+
+  ASSERT_TRUE(mutex.try_lock());
+  mutex.unlock();
+}
+
+TEST(LockGuard, logging) {
+  // set log level to INFO
+  FLAGS_caffe2_log_level = 0;
+
+  std::timed_mutex mutex;
+
+  mutex.lock();
+
+  auto loggingThread = std::async(std::launch::async, [&]() {
+    std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock};
+    ::c10d::detail::lockWithLogging(
+        name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__);
+  });
+
+  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10);
+  while (true) {
+    ASSERT_LT(std::chrono::system_clock::now(), deadline);
+
+    testing::internal::CaptureStderr();
+    std::this_thread::sleep_for(std::chrono::milliseconds(20));
+    std::string output = testing::internal::GetCapturedStderr();
+
+    if (output.find("my lock: waiting for lock for 10ms") !=
+        std::string::npos) {
+      break;
+    }
+  }
+
+  mutex.unlock();
+
+  loggingThread.get();
+}
test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
@@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught
       : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
         hasMonitorThreadCaughtError_(false) {}

-  std::mutex& getWatchdogMutex() {
+  std::timed_mutex& getWatchdogMutex() {
     return workMetaListMutex_;
   }

@@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
   work = pg.allreduce(tensors_);
   {
     // Now run all reduce with errors.
-    std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
+    std::lock_guard<std::timed_mutex> lock(pg.getWatchdogMutex());
     LOG(INFO) << "Lock watchdog thread.";
     // Wait long enough before monitor thread throws exceptions.
     std::this_thread::sleep_for(
torch/csrc/distributed/c10d/LockGuard.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
+
+#include <torch/csrc/distributed/c10d/logging.h>
+
+namespace c10d::detail {
+
+void lockWithLogging(
+    std::unique_lock<std::timed_mutex>& lock,
+    std::chrono::milliseconds log_interval,
+    const char* desc,
+    const char* file,
+    int line) {
+  while (!lock.try_lock_for(log_interval)) {
+    C10D_WARNING(
+        "{}:{} {}: waiting for lock for {}ms",
+        file,
+        line,
+        desc,
+        log_interval.count());
+  }
+}
+
+} // namespace c10d::detail
torch/csrc/distributed/c10d/LockGuard.hpp (new file, 32 lines)
@@ -0,0 +1,32 @@
+// Copyright (c) Meta Platforms, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <chrono>
+#include <mutex>
+
+#include <c10/macros/Export.h>
+
+namespace c10d::detail {
+
+// lockWithLogging is a wrapper around std::unique_lock<std::timed_mutex>
+// that automatically logs if the lock cannot be acquired within a given
+// timeout.
+TORCH_API void lockWithLogging(
+    std::unique_lock<std::timed_mutex>& lock,
+    std::chrono::milliseconds log_interval,
+    const char* desc,
+    const char* file,
+    int line);
+
+} // namespace c10d::detail
+
+// TODO: use std::source_location() when we can use C++20
+#define C10D_LOCK_GUARD(name, mutex)                                \
+  std::unique_lock<std::timed_mutex> name{mutex, std::defer_lock};  \
+  ::c10d::detail::lockWithLogging(                                  \
+      name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__)
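Call sites elsewhere in the diff migrate mechanically: the member mutex becomes a `std::timed_mutex` and the guard becomes the macro above. A hedged before/after sketch (the `Example` class and its fields are hypothetical; only `C10D_LOCK_GUARD` and `std::timed_mutex` come from this change, and the include only builds inside a PyTorch checkout):

#include <mutex>

#include <torch/csrc/distributed/c10d/LockGuard.hpp>

class Example { // hypothetical class, for illustration only
 public:
  bool isDone() const {
    // Before this change: std::lock_guard<std::mutex> lock(mutex_);
    // Now: if the mutex stays contended past the 30s interval, this logs
    // "<file>:<line> mutex_: waiting for lock for 30000ms" and keeps retrying.
    C10D_LOCK_GUARD(lock, mutex_);
    return done_;
  }

 private:
  mutable std::timed_mutex mutex_; // was std::mutex
  bool done_ = false;
};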
torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10;
 namespace c10d {

 ncclComm_t NCCLComm::getNcclComm() {
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   if (aborted_) {
     auto commFailureMsg = commFailureReason_ != std::nullopt
         ? c10::str(" Original reason for failure was: ", *commFailureReason_)
@@ -391,7 +391,7 @@ std::optional<size_t> NCCLTraceBuffer::record(
   }
   auto traceback =
       torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
-  std::lock_guard<std::mutex> guard(mutex_);
+  C10D_LOCK_GUARD(guard, mutex_);

   auto te = Entry{
       id_,
@@ -448,7 +448,7 @@ void NCCLTraceBuffer::record_pg_ranks(
   if (!enabled_) {
     return;
   }
-  std::lock_guard<std::mutex> guard(mutex_);
+  C10D_LOCK_GUARD(guard, mutex_);
   pg_name_to_ranks_[pg_name] = ranks;
 }

@@ -468,7 +468,7 @@ void NCCLTraceBuffer::update_state(Entry& r) {
 }

 std::vector<NCCLTraceBuffer::Entry> NCCLTraceBuffer::dump_entries() {
-  std::lock_guard<std::mutex> guard(mutex_);
+  C10D_LOCK_GUARD(guard, mutex_);
   std::vector<Entry> result;
   result.reserve(entries_.size());
   result.insert(result.end(), entries_.begin() + next_, entries_.end());
@@ -493,7 +493,7 @@ void NCCLTraceBuffer::retire_id(
   Event* endEvent = nullptr;
   std::optional<float> duration = std::nullopt;

-  std::unique_lock<std::mutex> guard(mutex_);
+  C10D_LOCK_GUARD(guard, mutex_);

   Entry* entry = &entries_.at(*id % max_entries_);
   if (entry->id_ == *id) {
torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -14,6 +14,7 @@
 #include <ATen/cuda/CUDAEvent.h>
 #include <c10/util/Exception.h>
 #include <nccl.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
 #include <optional>

@@ -271,7 +272,7 @@ class NCCLComm {
   ~NCCLComm() noexcept {
     // Add lock in this destructor, as aborted_ needs to be read after memory
     // barrier here.
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
     if (ncclComm_ && initialized_ && !aborted_) {
 #ifdef ENABLE_NCCL_ERROR_CHECKING
       // Use ncclCommAbort instead of ncclCommDestroy here since
@@ -363,7 +364,7 @@ class NCCLComm {
   NCCLComm(NCCLComm&& other) {
     // Using other's lock, as it reads other's states
     // Can not use this.mutex_, as this object is being constructed.
-    std::unique_lock<std::mutex> lock(other.mutex_);
+    C10D_LOCK_GUARD(lock, other.mutex_);
     std::swap(ncclComm_, other.ncclComm_);
     std::swap(aborted_, other.aborted_);
     std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
@@ -373,13 +374,13 @@ class NCCLComm {
   ncclComm_t getNcclComm();

   std::optional<std::string> getNcclCommFailureReason() const {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
     return commFailureReason_;
   }

   void ncclCommAbort(
       std::optional<std::string> commFailureReason = std::nullopt) {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
     if (aborted_ && !initialized_) {
       // Should not abort twice.
@@ -427,7 +428,7 @@ class NCCLComm {
   }

   bool isAborted() const {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
     return aborted_;
   }

@@ -436,7 +437,7 @@ class NCCLComm {
   }

   ncclResult_t checkForNcclError() {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
     if (ncclAsyncErr_ != ncclSuccess) {
       return ncclAsyncErr_;
@@ -451,7 +452,7 @@ class NCCLComm {
   }

   ncclResult_t registerSegment(void* ptr, size_t size) {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
     // We register only segments from cache allocator
     // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
@@ -482,7 +483,7 @@ class NCCLComm {
   }

   ncclResult_t deregisterSegment(void* ptr) {
-    std::unique_lock<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
     TORCH_CHECK(
         registeredSegmentHandles_.count(ptr) == 1,
@@ -519,7 +520,7 @@ class NCCLComm {
   bool aborted_;
   uint64_t ncclCommSplitCounter_{0};
   ncclResult_t ncclAsyncErr_;
-  mutable std::mutex mutex_;
+  mutable std::timed_mutex mutex_;
   // Rank that this communicator corresponds to.
   int rank_;
   // Optional reason for communicator failure, provided by ProcessGroupNCCL for
@@ -638,7 +639,7 @@ struct NCCLTraceBuffer {

   bool enabled_ = false;
   bool capture_cpp_stack_ = false;
-  std::mutex mutex_;
+  std::timed_mutex mutex_;
   std::vector<Entry> entries_;
   size_t max_entries_ = 0;
   size_t next_ = 0;
torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
 }

 int ProcessGroupGloo::RecvWork::sourceRank() const {
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::timed_mutex> lock(mutex_);
   return srcRank_;
 }

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -23,6 +23,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/thread_name.h>
 #include <torch/csrc/cuda/nccl.h>
+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
@@ -301,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 // hooks are called outside the scope of any PG, thus we need traverse
 // communicators in all PGs.
 static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
-static std::mutex ncclCommDevIdxMapMutex;
+static std::timed_mutex ncclCommDevIdxMapMutex;
 static bool allocatorHooksAttached = false;

 std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
@@ -314,7 +315,7 @@ void cacheAllocatorRegisterHook(
     return;
   }

-  std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+  C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
   for (auto& it : ncclCommDevIdxMap) {
     auto& ncclComm = it.first;
     auto& devIdx = it.second;
@@ -332,7 +333,7 @@ void cacheAllocatorDeregisterHook(
     return;
   }

-  std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+  C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex);
   for (auto& it : ncclCommDevIdxMap) {
     auto& ncclComm = it.first;
     auto& devIdx = it.second;
@@ -551,7 +552,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
   }

   auto exception_ptr = checkForNCCLErrors();
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   exception_ = exception_ptr;
   if (exception_) {
     LOG(ERROR) << logPrefix() << "Collective " << *this
@@ -567,7 +568,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {

 void ProcessGroupNCCL::WorkNCCL::setException(
     std::exception_ptr exception_ptr) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   exception_ = exception_ptr;
 }

@@ -776,12 +777,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {}
 std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
     bool timing) {
   auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
-    std::lock_guard<std::mutex> lock(this->cacheMutex_);
+    C10D_LOCK_GUARD(lock, this->cacheMutex_);
     this->eventsArray_[timing ? 1 : 0].push_back(event);
   };
   at::cuda::CUDAEvent* event = nullptr;
   {
-    std::lock_guard<std::mutex> lock(cacheMutex_);
+    C10D_LOCK_GUARD(lock, cacheMutex_);
     auto events = eventsArray_[timing ? 1 : 0];
     if (!events.empty()) {
       event = events.back();
@@ -1086,8 +1087,9 @@ void ProcessGroupNCCL::waitForPendingWorks() {
   while (true) {
     {
       std::lock(workMetaListMutex_, completedWorkListMutex_);
-      std::lock_guard<std::mutex> lockWork(workMetaListMutex_, std::adopt_lock);
-      std::lock_guard<std::mutex> lockHook(
+      std::lock_guard<std::timed_mutex> lockWork(
+          workMetaListMutex_, std::adopt_lock);
+      std::lock_guard<std::timed_mutex> lockHook(
           completedWorkListMutex_, std::adopt_lock);

       if (workMetaList_.empty() && completedWorkList_.empty()) {
@@ -1207,7 +1209,7 @@ bool ProcessGroupNCCL::abort(std::optional<std::string> abortReason) {
   }
   ncclCommDevIdxMapMutex.unlock();

-  std::lock_guard<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   abortCommsFromMap(devNCCLCommMap_, abortReason);
   abortCommsFromMap(inInitializationCommMap_, abortReason);
   return true;
@@ -1276,8 +1278,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
   // Serialize all calls to this function to avoid corrupting data, but allow
   // multiple calls in one runtime. User is responsible for preserving the
   // output file from an earlier call before a later call overwrites it.
-  static std::mutex writeDebugInfoMutex;
-  std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
+  static std::timed_mutex writeDebugInfoMutex;
+  C10D_LOCK_GUARD(lock, writeDebugInfoMutex);
   LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
   if (ncclTraceBufferSize_ > 0) {
     // We dump nccl trace into local disk by default and users can register
@@ -1356,7 +1358,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
     // This won't have any lock since this lock is only used here.
     // Please be aware that mutex `monitorMutex_` should not be used
     // somewhere else to avoid the deadlock.
-    std::unique_lock<std::mutex> lock(monitorMutex_);
+    C10D_LOCK_GUARD(lock, monitorMutex_);
     if (monitorWakeUpCV_.wait_for(
             lock, std::chrono::milliseconds(monitorPollInterval), [&] {
               return terminateHeartbeatMonitorThread_.load();
@@ -1681,7 +1683,7 @@ const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {

 void ProcessGroupNCCL::addEphemeralTimeout(
     const std::chrono::milliseconds& timeout) {
-  std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+  C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
   ephemeralTimeoutActive_ += timeout;
 }

@@ -1704,7 +1706,7 @@ void ProcessGroupNCCL::watchdogHandler() {
   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;

   while (!done || !terminateProcessGroup_.load()) {
-    std::unique_lock<std::mutex> lock(workMetaListMutex_);
+    C10D_LOCK_GUARD(lock, workMetaListMutex_);
     // We busy-poll the work vector every kWatchdogThreadSleepMillis
     // milliseconds as long as the atomic is True.
     workMetaListCV_.wait_for(
@@ -1876,7 +1878,7 @@ void ProcessGroupNCCL::watchdogHandler() {
       if (work.isCompleted()) {
         {
           // Reset the timeout and first work if the work is completed.
-          std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+          C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
           if (work.ownedEphermeralTimeout_.count() > 0) {
             ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
             ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
@@ -1891,7 +1893,7 @@ void ProcessGroupNCCL::watchdogHandler() {
         // Move Work object to completedWorkList_ to be consumed by the hook
         // thread
         {
-          const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
+          C10D_LOCK_GUARD(lock, completedWorkListMutex_);
           completedWorkList_.splice(
               completedWorkList_.end(), workMetaList_, it++);
         }
@@ -1919,7 +1921,7 @@ void ProcessGroupNCCL::runHookLoop() {

   bool done = false;
   while (!done || !terminateProcessGroup_.load()) {
-    std::unique_lock<std::mutex> lock(completedWorkListMutex_);
+    C10D_LOCK_GUARD(lock, completedWorkListMutex_);
     // We busy-poll the work vector every kWatchdogThreadSleepMillis
     // milliseconds as long as the atomic is True.
     completedWorkListCV_.wait_for(
@@ -2092,7 +2094,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
 }

 void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
-  std::lock_guard<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
     TORCH_INTERNAL_ASSERT(
         false,
@@ -2140,7 +2142,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
   usedDeviceIdxs_.insert(device.index());

   {
-    std::lock_guard<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
     if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) {
       // Reuse the cached communicator if there is one.
       return devNCCLCommMap_[deviceKey];
@@ -2214,7 +2216,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
         options_->split_color != 0,
         "Must specify a non-zero color when splitting");
     // Find a valid, healthy communicator to split from if possible.
-    std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
+    C10D_LOCK_GUARD(lock, options_->split_from->mutex_);
     auto& other_comms = options_->split_from->devNCCLCommMap_;
     auto dit = other_comms.find(getKeyFromDevice(device));
     if (dit != other_comms.end()) {
@@ -2268,7 +2270,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
       options_->is_high_priority_stream || force_high);

   {
-    std::lock_guard<std::mutex> lock(mutex_);
+    C10D_LOCK_GUARD(lock, mutex_);
     inInitializationCommMap_.emplace(deviceKey, ncclComm);
   }

@@ -2518,7 +2520,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
     const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
     const c10::intrusive_ptr<ProcessGroupNCCL::Options>& option) {
   std::chrono::milliseconds timeout = option->timeout;
-  std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
+  C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_);
   if (ephemeralTimeoutActive_.count() > 0) {
     timeout += ephemeralTimeoutActive_;
   }
@@ -2531,7 +2533,7 @@ void ProcessGroupNCCL::assignTimeoutToWork(
 void ProcessGroupNCCL::workEnqueue(
     c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) {
   if (!terminateProcessGroup_.load()) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
+    C10D_LOCK_GUARD(lock, workMetaListMutex_);
     // Avoid view tensors to be processed in cleanup thread.
     // View tensors' destruction invokes autograd_meta, which
     // needs to be destructed in user thread. Otherwise will
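One spot above intentionally keeps a plain blocking acquisition: `waitForPendingWorks` takes two mutexes at once, so it still uses `std::lock` for deadlock-free ordering and only adopts the already-held locks into `std::lock_guard<std::timed_mutex>`. A minimal sketch of that pattern (variable names shortened from the originals):

#include <mutex>

std::timed_mutex workMutex; // stands in for workMetaListMutex_
std::timed_mutex hookMutex; // stands in for completedWorkListMutex_

void lockBoth() {
  // std::lock acquires both mutexes without deadlocking, regardless of the
  // order in which other threads take them.
  std::lock(workMutex, hookMutex);
  // adopt_lock: the guards take ownership of the already-locked mutexes
  // and release them on scope exit.
  std::lock_guard<std::timed_mutex> lockWork(workMutex, std::adopt_lock);
  std::lock_guard<std::timed_mutex> lockHook(hookMutex, std::adopt_lock);
  // ... inspect both work lists while holding both locks ...
}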
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -449,7 +449,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   static CUDAEventCache& get();

  private:
-  std::mutex cacheMutex_;
+  std::timed_mutex cacheMutex_;
   // NOTE: We intentionaly store raw pointers so that
   // we do not attempt to destroy the event objects on process exit,
   // because cuda may be gone.
@@ -920,7 +920,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // ephemeralTimeoutActive_/ephemeralTimeoutInflight_.
   // TODO(fduwjj): We need to have an audit on all mutexes we are adding here.
   // And consolidate them if possible.
-  std::mutex mtxTimeoutExtension_;
+  std::timed_mutex mtxTimeoutExtension_;

   // The ephemeral timeout added on top of existing timeout for works issued
   // before first work finishes.
@@ -980,7 +980,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       inInitializationCommMap_;

   // Mutex to guard maps like devNCCLCommMap_.
-  std::mutex mutex_;
+  std::timed_mutex mutex_;

   // Heartbeat of watchdog thread.
   std::atomic_uint64_t heartbeat_;
@@ -1041,18 +1041,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   static std::atomic<bool> shouldDump_;

   // Mutex to Guard workMetaList_
-  std::mutex workMetaListMutex_;
+  std::timed_mutex workMetaListMutex_;

   // Mutex to Guard monitorWakeUpCV_
-  std::mutex monitorMutex_;
+  std::timed_mutex monitorMutex_;

   bool writeDebugInfo_ = false;

   // Condition Variable for watchdog thread sleep
-  std::condition_variable workMetaListCV_;
+  std::condition_variable_any workMetaListCV_;

   // Condition Variable for monitor thread to wake up early
-  std::condition_variable monitorWakeUpCV_;
+  std::condition_variable_any monitorWakeUpCV_;

   // Vector to Store WorkNCCL pointers
   std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
@@ -1060,10 +1060,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;

   // Mutex to Guard workMetaList_
-  std::mutex completedWorkListMutex_;
+  std::timed_mutex completedWorkListMutex_;

   // Condition Variable for watchdog thread sleep
-  std::condition_variable completedWorkListCV_;
+  std::condition_variable_any completedWorkListCV_;

   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;

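The `std::condition_variable` to `std::condition_variable_any` swaps above follow from the mutex change: plain `std::condition_variable` only accepts `std::unique_lock<std::mutex>`, while the `_any` variant waits on any lock type, including the `std::unique_lock<std::timed_mutex>` that `C10D_LOCK_GUARD` produces. A minimal sketch (names illustrative):

#include <chrono>
#include <condition_variable>
#include <mutex>

std::timed_mutex mtx;
std::condition_variable_any cv; // _any works with unique_lock<timed_mutex>
bool ready = false;

bool waitUntilReady() {
  std::unique_lock<std::timed_mutex> lock(mtx);
  // With std::condition_variable this call would not compile, since its
  // wait_for overloads only take std::unique_lock<std::mutex>.
  return cv.wait_for(
      lock, std::chrono::milliseconds(100), [] { return ready; });
}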
torch/csrc/distributed/c10d/Work.cpp
@@ -1,5 +1,6 @@
 #include <ATen/ThreadLocalState.h>

+#include <torch/csrc/distributed/c10d/LockGuard.hpp>
 #include <torch/csrc/distributed/c10d/Work.hpp>
 #include <utility>

@@ -45,17 +46,17 @@ OpType Work::retrieveOpType() const {
 Work::~Work() = default;

 bool Work::isCompleted() {
-  std::lock_guard<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   return completed_;
 }

 bool Work::isSuccess() const {
-  std::lock_guard<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   return !exception_;
 }

 std::exception_ptr Work::exception() const {
-  std::lock_guard<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   return exception_;
 }

@@ -73,7 +74,7 @@ std::vector<at::Tensor> Work::result() {
 void Work::synchronize() {}

 bool Work::wait(std::chrono::milliseconds timeout) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   if (timeout == kNoTimeout) {
     // This waits without a timeout.
     cv_.wait(lock, [&] { return completed_; });
@@ -103,7 +104,7 @@ c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
 }

 void Work::finish(std::exception_ptr exception) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   completed_ = true;
   exception_ = std::move(exception);
   if (recordFunctionEndCallback_) {
@@ -115,7 +116,7 @@ void Work::finish(std::exception_ptr exception) {
 }

 void Work::finishAndThrow(std::exception_ptr exception) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  C10D_LOCK_GUARD(lock, mutex_);
   completed_ = true;
   exception_ = std::move(exception);
   if (recordFunctionEndCallback_) {
|
|||||||
@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder {
|
|||||||
// provided by the user.
|
// provided by the user.
|
||||||
void finishAndThrow(std::exception_ptr exception);
|
void finishAndThrow(std::exception_ptr exception);
|
||||||
|
|
||||||
mutable std::mutex mutex_;
|
mutable std::timed_mutex mutex_;
|
||||||
std::condition_variable cv_;
|
std::condition_variable_any cv_;
|
||||||
bool completed_ = false;
|
bool completed_ = false;
|
||||||
std::exception_ptr exception_;
|
std::exception_ptr exception_;
|
||||||
|
|
||||||
|
|||||||