From a33eac398881cfa9aad679ceffd28ace3fa44f01 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 12 Apr 2023 18:33:10 +0000 Subject: [PATCH] [NCCL] Add experimental Nonblocking NCCL Fault Tolerance/Checking (#95715) Support for nonblocking NCCL communicators/fault tolerance/checking which was added in 2.14 as an experimental feature. Enabled via the environment variable: ``` TORCH_NCCL_USE_COMM_NONBLOCKING=1 ``` CC @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/95715 Approved by: https://github.com/kwen2501 --- torch/_C/_distributed_c10d.pyi | 6 +- torch/csrc/cuda/nccl.cpp | 181 +++++++++++++++++- torch/csrc/cuda/nccl.h | 6 +- torch/csrc/distributed/c10d/NCCLUtils.cpp | 30 +++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 76 +++++++- .../distributed/c10d/ProcessGroupNCCL.cpp | 92 ++++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 6 +- torch/csrc/distributed/c10d/init.cpp | 6 +- 8 files changed, 384 insertions(+), 19 deletions(-) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index b7dab12fdc43..840216318639 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -408,10 +408,8 @@ class ProcessGroupNCCL(ProcessGroup): size: int, timeout: timedelta, ): ... - @staticmethod - def _group_start() -> None: ... - @staticmethod - def _group_end() -> None: ... + def _group_start(self) -> None: ... + def _group_end(self) -> None: ... class ProcessGroupUCC(ProcessGroup): def __init__( diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index dcc3a4560203..441ad3205c0c 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -16,6 +16,11 @@ #include #include +#if !defined(USE_ROCM) && \ + ((NCCL_MACJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14))) +#define NCCL_HAS_COMM_NONBLOCKING 1 +#endif + ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) { return reinterpret_cast(var); } @@ -44,6 +49,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) { return ncclResult_t::ncclInvalidUsage; case torch::cuda::nccl::ncclResult::NumResults: return ncclResult_t::ncclNumResults; +#ifdef NCCL_HAS_COMM_NONBLOCKING + case torch::cuda::nccl::ncclResult::InProgress: + return ncclResult_t::ncclInProgress; +#endif default: throw std::runtime_error("Unconvertible NCCL type"); } @@ -65,6 +74,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) { return torch::cuda::nccl::ncclResult::InvalidUsage; case ncclNumResults: return torch::cuda::nccl::ncclResult::NumResults; +#ifdef NCCL_HAS_COMM_NONBLOCKING + case ncclInProgress: + return torch::cuda::nccl::ncclResult::InProgress; +#endif default: throw std::runtime_error("Unconvertible NCCL type"); } @@ -123,6 +136,105 @@ static inline void NCCL_CHECK(ncclResult_t result) { NCCL_CHECK(from_nccl_result(result)); } +// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp? 
+bool nccl_use_nonblocking() { + static bool nccl_use_nonblocking_ = + c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true; + if (nccl_use_nonblocking_) { + TORCH_WARN("Using experimental non-blocking NCCL communicator."); + } + return nccl_use_nonblocking_; +} + +static int _parse_nccl_nonblocking_timeout() { + const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); + int timeout = -1; + if (val) { + const std::string config(val); + timeout = std::stoi(config); + if (!nccl_use_nonblocking() && timeout > 0) { + TORCH_WARN( + "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); + timeout = -1; + } + } + return timeout; +} + +static int nccl_nonblocking_timeout() { + static int timeout = _parse_nccl_nonblocking_timeout(); + return timeout; +} + +static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) { +#ifdef NCCL_HAS_COMM_NONBLOCKING + ncclResult_t result = to_nccl_result(status); + auto startTimepoint = std::chrono::steady_clock::now(); + while (result == ncclInProgress) { + if (nccl_nonblocking_timeout() > 0) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - startTimepoint) + .count(); + if (timeElapsed > nccl_nonblocking_timeout()) { + throw std::runtime_error("NCCL timeout."); + } + } + ncclCommGetAsyncError(to_nccl_comm(comm), &result); + } + if (result != ncclSuccess) { + throw_nccl_error(from_nccl_result(result)); + } +#else + TORCH_INTERNAL_ASSERT( + false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION."); +#endif +} + +static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) { + NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm); +} + +static inline void NCCL_CHECK_TIMEOUT( + ncclResult status, + std::vector& comms) { +#ifdef NCCL_HAS_COMM_NONBLOCKING + ncclResult_t result = to_nccl_result(status); + auto startTimepoint = std::chrono::steady_clock::now(); + if (result == ncclInProgress) { + for (const auto i : c10::irange(comms.size())) { + do { + if (nccl_nonblocking_timeout() > 0) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - startTimepoint) + .count(); + if (timeElapsed > nccl_nonblocking_timeout()) { + throw std::runtime_error("NCCL timeout."); + } + } + ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result); + } while (result == ncclInProgress); + if (result != ncclSuccess) { + break; /* fall through to failed case */ + } + } + } + if (result != ncclSuccess) { + throw_nccl_error(from_nccl_result(result)); + } +#else + TORCH_INTERNAL_ASSERT( + false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION."); +#endif +} + +static inline void NCCL_CHECK_TIMEOUT( + ncclResult_t result, + std::vector& comms) { + NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms); +} + void throw_nccl_error(torch::cuda::nccl::ncclResult status) { std::ostringstream err; err << "NCCL Error " << static_cast(status) << ": " @@ -308,9 +420,25 @@ AutoNcclGroup::AutoNcclGroup() { #endif } +AutoNcclGroup::AutoNcclGroup( + std::vector& comms, + bool comm_nonblocking) { + (c10::cuda::getFreeMutex())->lock(); + // TODO(eqy): can we make comms_ reference? 
+ comms_ = comms; + comm_nonblocking_ = comm_nonblocking; +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) + detail::NCCL_CHECK(ncclGroupStart()); +#endif +} + AutoNcclGroup::~AutoNcclGroup() noexcept(false) { #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) - detail::NCCL_CHECK(ncclGroupEnd()); + if (!comm_nonblocking_) { + detail::NCCL_CHECK(ncclGroupEnd()); + } else { + detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_); + } #endif #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) (c10::cuda::getFreeMutex())->unlock(); @@ -677,7 +805,11 @@ void all2all_single_equal_split( ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); } } +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclGroupEnd()); +#else + NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); +#endif #endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); @@ -730,7 +862,11 @@ void all2all_single_unequal_split( stream)); } } +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclGroupEnd()); +#else + NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif @@ -773,7 +909,11 @@ void all2all( stream.stream())); } } +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclGroupEnd()); +#else + NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif @@ -791,6 +931,7 @@ void send( #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 7) using namespace torch::cuda::nccl::detail; +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclSend( input.data_ptr(), input.numel(), @@ -798,6 +939,17 @@ void send( dst, to_nccl_comm(comm), stream.stream())); +#else + NCCL_CHECK_TIMEOUT( + ncclSend( + input.data_ptr(), + input.numel(), + to_nccl_data_type(input), + dst, + to_nccl_comm(comm), + stream.stream()), + comm); +#endif #else AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); #endif @@ -815,6 +967,7 @@ void recv( #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 7) using namespace torch::cuda::nccl::detail; +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclRecv( output.data_ptr(), output.numel(), @@ -822,6 +975,17 @@ void recv( src, to_nccl_comm(comm), stream.stream())); +#else + NCCL_CHECK_TIMEOUT( + ncclRecv( + output.data_ptr(), + output.numel(), + to_nccl_data_type(output), + src, + to_nccl_comm(comm), + stream.stream()), + comm); +#endif #else AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); #endif @@ -865,7 +1029,11 @@ void gather( } else { NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream)); } +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclGroupEnd()); +#else + NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); +#endif #else AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0"); @@ -888,9 +1056,13 @@ void scatter( auto comm = to_nccl_comm(_comm); int numranks, cur_rank; +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclCommUserRank(comm, &cur_rank)); - +#else + NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm); + NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm); +#endif NCCL_CHECK(ncclGroupStart()); if (cur_rank == root) { for (const auto r : c10::irange(numranks)) { @@ -910,8 +1082,11 @@ void scatter( auto* recvbuff = reinterpret_cast(outputs.data_ptr()); NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream)); } +#ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclGroupEnd()); - 
+#else + NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); +#endif #else AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0"); #endif diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index f9f4fa8b1353..7640c911c307 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -46,7 +46,8 @@ enum class ncclResult { InternalError = 3, InvalidArgument = 4, InvalidUsage = 5, - NumResults = 6 + NumResults = 6, + InProgress = 7 }; /* Reduction operation selector */ @@ -77,7 +78,10 @@ enum class ncclDataType { // manages group and lock lifetimes. struct AutoNcclGroup { AutoNcclGroup(); + AutoNcclGroup(std::vector& comms, bool comm_nonblocking); ~AutoNcclGroup() noexcept(false); + std::vector comms_; + bool comm_nonblocking_; }; // NOTE: this is exposed only so that python_nccl.cpp can some of these helpers. diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index a5724062b4a3..fbceb81cf932 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,6 +1,7 @@ #include #include +#include #ifdef USE_C10D_NCCL @@ -52,6 +53,35 @@ std::string getNcclVersion() { return versionString; } +bool nccl_use_nonblocking() { + static bool nccl_use_nonblocking_ = + c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true; + if (nccl_use_nonblocking_) { + TORCH_WARN("Using experimental non-blocking NCCL communicator."); + } + return nccl_use_nonblocking_; +} + +int _parse_nccl_nonblocking_timeout() { + const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); + int timeout = -1; + if (val) { + const std::string config(val); + timeout = std::stoi(config); + if (!nccl_use_nonblocking() && timeout > 0) { + TORCH_WARN( + "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); + timeout = -1; + } + } + return timeout; +} + +int nccl_nonblocking_timeout() { + static int timeout = _parse_nccl_nonblocking_timeout(); + return timeout; +} + std::string ncclGetErrorWithVersion(ncclResult_t error) { return std::string(ncclGetErrorString(error)) + ", NCCL version " + getNcclVersion(); diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 9f45ec61e09b..0e0b98cd4870 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -12,6 +12,11 @@ #include #include +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 14) +#define NCCL_HAS_COMM_NONBLOCKING +#endif + // ncclGetLastError() is enabled only for NCCL versions 2.13+ // ncclRemoteError only exists in NCCL versions 2.13+ #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -59,6 +64,60 @@ } \ } while (0) +// Macro to throw on a non-successful NCCL return value, non-blocking. 
+#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \ + ncclResult_t result = cmd; \ + auto startTimepoint = std::chrono::steady_clock::now(); \ + while (result == ncclInProgress) { \ + if (nccl_nonblocking_timeout() > 0) { \ + auto currentTimepoint = std::chrono::steady_clock::now(); \ + auto timeElapsed = std::chrono::duration_cast(currentTimepoint - startTimepoint).count(); \ + if (timeElapsed > nccl_nonblocking_timeout()) { \ + std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ + "\n" + getNcclErrorDetailStr(result, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } \ + } \ + ncclCommGetAsyncError(comm, &result); \ + } \ + if (result != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ + "\n" + getNcclErrorDetailStr(result, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } + +#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \ + ncclResult_t state = cmd; \ + auto startTimepoint = std::chrono::steady_clock::now(); \ + if (state == ncclInProgress) { \ + for (const auto i : c10::irange(comms_.size())) { \ + do { \ + if (nccl_nonblocking_timeout() > 0) { \ + auto currentTimepoint = std::chrono::steady_clock::now(); \ + auto timeElapsed = std::chrono::duration_cast(currentTimepoint - startTimepoint).count(); \ + if (timeElapsed > nccl_nonblocking_timeout()) { \ + std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \ + "\n" + getNcclErrorDetailStr(state, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } \ + } \ + ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \ + } while (state == ncclInProgress); \ + if (state != ncclSuccess) { \ + break; /* fall through to failed case */ \ + } \ + } \ + } \ + if (state != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \ + "\n" + getNcclErrorDetailStr(state, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } + // Macro to print and abort on a non-successful NCCL return value. #define C10D_NCCL_ASSERT(cmd) \ do { \ @@ -79,6 +138,8 @@ namespace c10d { std::string getNcclVersion(); std::string ncclGetErrorWithVersion(ncclResult_t error); +bool nccl_use_nonblocking(); +int nccl_nonblocking_timeout(); // Provides additional detail into NCCL error codes based on when these are // thrown in the NCCL codebase. @@ -118,8 +179,17 @@ class NCCLComm { int rank, ncclUniqueId commId) { auto comm = std::make_shared(); +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt); +#else + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; + if (nccl_use_nonblocking()) { + config.blocking = 0; + } + C10D_NCCL_CHECK_TIMEOUT( + ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt); +#endif comm->ncclId_ = commId; comm->rank_ = rank; return comm; @@ -165,8 +235,12 @@ class NCCLComm { // Set true failure reason if provided by ProcessGroupNCCL (e.g. 
work // timeout) commFailureReason_ = commFailureReason; - +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); +#else + C10D_NCCL_CHECK_TIMEOUT( + ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_); +#endif aborted_ = true; ncclComm_ = nullptr; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index bd095115e681..a506bafa76f9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1120,6 +1120,7 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // the following for loop. for (const auto i : c10::irange(ncclActiveGroupCounter_)) { (void)i; + // comms have not been initiated yet, so can only check in blocking-way C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); } @@ -1156,7 +1157,15 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( } // [Note 2 ] +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); +#else + if (!nccl_use_nonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms, c10::nullopt); + } +#endif // At this point NCCL should have been initialized, hence we can accurately // get the env value even if NCCL sets it by reading from nccl.conf file @@ -1387,7 +1396,19 @@ void ProcessGroupNCCL::startCoalescing() { void ProcessGroupNCCL::endCoalescing( std::vector>& reqs) { - groupEnd(); + if (!nccl_use_nonblocking()) { + groupEnd(); + } else { + std::vector> ncclComms_; + for (const auto& req : reqs) { + auto ncclWork = static_cast(req.get()); + ncclComms_.insert( + ncclComms_.end(), + ncclWork->ncclComms_.begin(), + ncclWork->ncclComms_.end()); + } + groupEndNonblocking(ncclComms_); + } if (reqs.size() != coalescedDevices_.size()) { TORCH_CHECK(false, "Number of requests do not match number of collectives"); } @@ -1478,8 +1499,17 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( pre(ncclStreams, work); + std::vector comms_; + if (nccl_use_nonblocking()) { + for (const auto i : c10::irange(inputs.size())) { + decltype(i) stream_comm_i = (inputs_same_dev ? 
0 : i); + comms_.push_back((void*)ncclComms[stream_comm_i]->getNcclComm()); + } + } + { - torch::cuda::nccl::AutoNcclGroup nccl_group_guard; + torch::cuda::nccl::AutoNcclGroup nccl_group_guard( + comms_, nccl_use_nonblocking()); for (const auto i : c10::irange(inputs.size())) { if (!inputs_same_dev || (inputs_same_dev && i == 0)) { gpuGuard.set_index(devices[i].index()); @@ -1499,12 +1529,18 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::cuda::CUDACachingAllocator::recordStream( inputs[i].storage().data_ptr(), ncclStream); } +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream), ncclComm->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream), + ncclComm->getNcclComm(), + ncclComm->getNcclCommFailureReason()); +#endif } } - post(ncclStreams); // End event should only be recorded after the ncclGroupEnd() @@ -1634,17 +1670,34 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( tensors[i].storage().data_ptr(), ncclStream); } + std::vector comms_; + if (nccl_use_nonblocking()) { + for (const auto i : c10::irange(tensors.size())) { + comms_.push_back((void*)ncclComms[i]->getNcclComm()); + } + } { - torch::cuda::nccl::AutoNcclGroup nccl_group_guard; + torch::cuda::nccl::AutoNcclGroup nccl_group_guard( + comms_, nccl_use_nonblocking()); for (const auto i : c10::irange(tensors.size())) { gpuGuard.set_index(devices[i].index()); at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank), ncclComms[i]->getNcclCommFailureReason()); +#else + C10D_NCCL_CHECK_TIMEOUT( + fn(tensors[i], + ncclComms[i]->getNcclComm(), + ncclStream, + p2pTargetRank), + ncclComms[i]->getNcclComm(), + ncclComms[i]->getNcclCommFailureReason()); +#endif } } @@ -2610,7 +2663,38 @@ void ProcessGroupNCCL::groupStart() { void ProcessGroupNCCL::groupEnd() { #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); +#else + if (!nccl_use_nonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + } else { + TORCH_WARN( + "ProcessGroupNCCL::groupEnd() called in nonblocking communicator mode without involved communicators specified; gathering all mapped communicators..."); + std::unique_lock lock(mutex_); + std::vector> ncclComms_; + for (auto& it : devNCCLCommMap_) { + ncclComms_.insert(ncclComms_.end(), it.second.begin(), it.second.end()); + } + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms_, c10::nullopt); + } +#endif +#endif + --ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEndNonblocking( + std::vector> comms) { +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) +#ifndef NCCL_HAS_COMM_NONBLOCKING + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); +#else + if (!nccl_use_nonblocking()) { + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comms, c10::nullopt); + } +#endif #endif --ncclActiveGroupCounter_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 1fdabd1731d3..5378e69cabf8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -404,9 +404,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { int srcRank, int tag) override; - static void groupStart(); + void groupStart(); - 
static void groupEnd(); + void groupEnd(); + + void groupEndNonblocking(std::vector> comms); // Unsupported Ops c10::intrusive_ptr gather( diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 85b707409c02..72127ed232de 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1970,6 +1970,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). }, py::arg("abort_reason") = py::none(), py::call_guard()) + .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart) + .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd) .def_property_readonly( "options", &::c10d::ProcessGroupNCCL::getOptions) .def_property_readonly( @@ -1999,10 +2001,6 @@ Example:: .def_readwrite( "is_high_priority_stream", &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream); - processGroupNCCL.def_static( - "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); }); - processGroupNCCL.def_static( - "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); }); #endif #ifdef USE_C10D_MPI
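
The new `NCCL_CHECK_TIMEOUT` helpers in `torch/csrc/cuda/nccl.cpp` and the `C10D_NCCL_CHECK_TIMEOUT` macro in `NCCLUtils.hpp` both follow the same nonblocking-communicator pattern: with `TORCH_NCCL_USE_COMM_NONBLOCKING=1` a NCCL call may return `ncclInProgress`, and the caller polls `ncclCommGetAsyncError` until the operation completes or the optional `TORCH_NCCL_NONBLOCKING_TIMEOUT` budget is exhausted. The sketch below is illustrative only and is not part of the patch; it restates that polling loop against a bare `ncclComm_t`, assuming NCCL >= 2.14. The helper name `waitForNcclComm` and the seconds-based `timeout_sec` parameter are choices made for this example, not names from the diff.

```cpp
// Illustrative sketch, not part of the patch: the polling loop that
// NCCL_CHECK_TIMEOUT / C10D_NCCL_CHECK_TIMEOUT implement for nonblocking
// communicators (requires NCCL >= 2.14 for ncclInProgress).
#include <chrono>
#include <stdexcept>
#include <string>

#include <nccl.h>

void waitForNcclComm(ncclComm_t comm, ncclResult_t status, int timeout_sec) {
  auto start = std::chrono::steady_clock::now();
  while (status == ncclInProgress) {
    if (timeout_sec > 0) {
      auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                         std::chrono::steady_clock::now() - start)
                         .count();
      if (elapsed > timeout_sec) {
        throw std::runtime_error("NCCL timeout.");
      }
    }
    // For a nonblocking communicator the async state eventually leaves
    // ncclInProgress, ending in ncclSuccess or a real error code.
    ncclCommGetAsyncError(comm, &status);
  }
  if (status != ncclSuccess) {
    throw std::runtime_error(
        std::string("NCCL error: ") + ncclGetErrorString(status));
  }
}
```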
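
`NCCLComm::create()` in `NCCLUtils.hpp` switches from `ncclCommInitRank` to `ncclCommInitRankConfig` with `config.blocking = 0` when the environment variable is set, and wraps the call in `C10D_NCCL_CHECK_TIMEOUT`. A minimal standalone version of that call sequence, reusing the hypothetical `waitForNcclComm` helper from the previous sketch:

```cpp
// Illustrative sketch, not part of the patch: requesting a nonblocking
// communicator, as NCCLComm::create() does when
// TORCH_NCCL_USE_COMM_NONBLOCKING=1.
#include <nccl.h>

ncclComm_t createNonblockingComm(int numRanks, int rank, ncclUniqueId id) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0; // ask NCCL for a nonblocking communicator
  ncclComm_t comm = nullptr;
  // With a nonblocking config this call may return ncclInProgress; the
  // handle is still populated and must be polled before use.
  ncclResult_t status =
      ncclCommInitRankConfig(&comm, numRanks, id, rank, &config);
  waitForNcclComm(comm, status, /*timeout_sec=*/-1);
  return comm;
}
```

In the patch the same polling is folded into the `C10D_NCCL_CHECK_TIMEOUT` macro, so callers of `NCCLComm::create()` do not change.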
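
`ncclGroupEnd()` is the other place where `ncclInProgress` can surface: every communicator that took part in the group has to be drained, which is what `NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms)` in `nccl.cpp` and `C10D_NCCL_CHECK_TIMEOUT_GROUPEND` in `NCCLUtils.hpp` do. A simplified sketch of that shape (the patch measures the timeout from a single start point, whereas this example restarts it per communicator):

```cpp
// Illustrative sketch, not part of the patch: draining a group of
// nonblocking communicators after ncclGroupEnd(), as the new group-end
// handling does.
#include <stdexcept>
#include <string>
#include <vector>

#include <nccl.h>

void groupEndNonblocking(
    const std::vector<ncclComm_t>& comms,
    int timeout_sec) {
  ncclResult_t status = ncclGroupEnd();
  if (status == ncclInProgress) {
    for (ncclComm_t comm : comms) {
      // Poll each communicator to completion (helper from the first sketch).
      waitForNcclComm(comm, ncclInProgress, timeout_sec);
    }
  } else if (status != ncclSuccess) {
    throw std::runtime_error(
        std::string("NCCL error: ") + ncclGetErrorString(status));
  }
}
```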
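
Putting the pieces together, `ProcessGroupNCCL::collective()` and `pointToPoint()` now build the `AutoNcclGroup` from the participating communicators and check each enqueue with `C10D_NCCL_CHECK_TIMEOUT`. Stripped of the PyTorch plumbing, the flow for a single communicator looks roughly like the sketch below; this is a sketch under the same assumptions as above, not the actual ProcessGroupNCCL code:

```cpp
// Illustrative sketch, not part of the patch: enqueueing an allreduce on a
// nonblocking communicator and waiting until the group has been issued.
#include <cstddef>

#include <cuda_runtime.h>
#include <nccl.h>

void allReduceNonblocking(
    const float* sendbuf,
    float* recvbuf,
    size_t count,
    ncclComm_t comm,
    cudaStream_t stream,
    int timeout_sec) {
  ncclGroupStart();
  // The enqueue itself may report ncclInProgress on a nonblocking
  // communicator; mirror C10D_NCCL_CHECK_TIMEOUT and poll it.
  ncclResult_t status =
      ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream);
  waitForNcclComm(comm, status, timeout_sec);
  // Same for closing the group (single-communicator case).
  status = ncclGroupEnd();
  waitForNcclComm(comm, status, timeout_sec);
}
```

As in the patch, this only guarantees the work has been handed to NCCL; completion of the collective is still tracked through CUDA streams and events, which this PR does not change.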
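
The fault-tolerance side follows the same rule: `NCCLComm::abort()` now routes `ncclCommAbort` through `C10D_NCCL_CHECK_TIMEOUT`, because abort on a nonblocking communicator can itself report `ncclInProgress`. A final sketch, again illustrative only and relying on the helper defined in the first example:

```cpp
// Illustrative sketch, not part of the patch: aborting a nonblocking
// communicator, as NCCLComm::abort() now does via C10D_NCCL_CHECK_TIMEOUT.
#include <nccl.h>

void abortNonblockingComm(ncclComm_t comm, int timeout_sec) {
  // ncclCommAbort may also return ncclInProgress on a nonblocking
  // communicator, so it is polled like any other call.
  ncclResult_t status = ncclCommAbort(comm);
  waitForNcclComm(comm, status, timeout_sec);
}
```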