[NCCL] Add experimental Nonblocking NCCL Fault Tolerance/Checking (#95715)

Adds support for nonblocking NCCL communicators and fault tolerance/checking, which NCCL 2.14 introduced as an experimental feature. When enabled, communicators are created via `ncclCommInitRankConfig` with `config.blocking = 0`; NCCL calls may then return `ncclInProgress`, and the new `NCCL_CHECK_TIMEOUT`/`C10D_NCCL_CHECK_TIMEOUT` helpers poll `ncclCommGetAsyncError` until each operation completes. The polling can optionally be bounded by setting `TORCH_NCCL_NONBLOCKING_TIMEOUT` (in seconds).
Enabled via the environment variable:
```
TORCH_NCCL_USE_COMM_NONBLOCKING=1
```
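For illustration, a minimal usage sketch (assumptions not in this PR: a single node with at least two GPUs, launching via `torchrun`, and the hypothetical script name `demo.py`); any NCCL-backed collective exercises the nonblocking path once the variable is set:
```python
# demo.py -- hypothetical launch command:
#   TORCH_NCCL_USE_COMM_NONBLOCKING=1 torchrun --nproc_per_node=2 demo.py
import os

import torch
import torch.distributed as dist

# The variable must be set before the NCCL communicator is created.
assert os.environ.get("TORCH_NCCL_USE_COMM_NONBLOCKING") == "1"

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())  # one GPU per rank on a single node

# Collectives look the same as with blocking communicators; internally each
# NCCL call is polled while it reports ncclInProgress instead of blocking.
t = torch.ones(8, device="cuda")
dist.all_reduce(t)
torch.cuda.synchronize()

dist.destroy_process_group()
```
At the Python level nothing else changes; the retry/poll logic lives in the C++ changes below.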

CC @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95715
Approved by: https://github.com/kwen2501
Author: Eddie Yan
Date: 2023-04-12 18:33:10 +00:00
Committed by: PyTorch MergeBot
Parent: 09458a2bf1
Commit: a33eac3988

8 changed files with 384 additions and 19 deletions


@@ -408,10 +408,8 @@ class ProcessGroupNCCL(ProcessGroup):
size: int,
timeout: timedelta,
): ...
@staticmethod
def _group_start() -> None: ...
@staticmethod
def _group_end() -> None: ...
def _group_start(self) -> None: ...
def _group_end(self) -> None: ...
class ProcessGroupUCC(ProcessGroup):
def __init__(


@@ -16,6 +16,11 @@
#include <type_traits>
#include <unordered_map>
#if !defined(USE_ROCM) && \
((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}
@@ -44,6 +49,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
return ncclResult_t::ncclInvalidUsage;
case torch::cuda::nccl::ncclResult::NumResults:
return ncclResult_t::ncclNumResults;
#ifdef NCCL_HAS_COMM_NONBLOCKING
case torch::cuda::nccl::ncclResult::InProgress:
return ncclResult_t::ncclInProgress;
#endif
default:
throw std::runtime_error("Unconvertible NCCL type");
}
@@ -65,6 +74,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
return torch::cuda::nccl::ncclResult::InvalidUsage;
case ncclNumResults:
return torch::cuda::nccl::ncclResult::NumResults;
#ifdef NCCL_HAS_COMM_NONBLOCKING
case ncclInProgress:
return torch::cuda::nccl::ncclResult::InProgress;
#endif
default:
throw std::runtime_error("Unconvertible NCCL type");
}
@@ -123,6 +136,105 @@ static inline void NCCL_CHECK(ncclResult_t result) {
NCCL_CHECK(from_nccl_result(result));
}
// TODO(eqy): can this duplication with NCCLUtils.cpp be avoided?
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN("Using experimental non-blocking NCCL communicator.");
}
return nccl_use_nonblocking_;
}
static int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
}
}
return timeout;
}
static int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
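// Single-communicator check: while the result is ncclInProgress, poll
// ncclCommGetAsyncError() until the call completes; if
// TORCH_NCCL_NONBLOCKING_TIMEOUT (seconds) is set, give up once it elapses.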
static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclResult_t result = to_nccl_result(status);
auto startTimepoint = std::chrono::steady_clock::now();
while (result == ncclInProgress) {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
}
ncclCommGetAsyncError(to_nccl_comm(comm), &result);
}
if (result != ncclSuccess) {
throw_nccl_error(from_nccl_result(result));
}
#else
TORCH_INTERNAL_ASSERT(
false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}
static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
}
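// Overload for grouped calls: polls every communicator in the group.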
static inline void NCCL_CHECK_TIMEOUT(
ncclResult status,
std::vector<ncclComm_t>& comms) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclResult_t result = to_nccl_result(status);
auto startTimepoint = std::chrono::steady_clock::now();
if (result == ncclInProgress) {
for (const auto i : c10::irange(comms.size())) {
do {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
}
ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
} while (result == ncclInProgress);
if (result != ncclSuccess) {
break; /* fall through to failed case */
}
}
}
if (result != ncclSuccess) {
throw_nccl_error(from_nccl_result(result));
}
#else
TORCH_INTERNAL_ASSERT(
false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}
static inline void NCCL_CHECK_TIMEOUT(
ncclResult_t result,
std::vector<ncclComm_t>& comms) {
NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
}
void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
std::ostringstream err;
err << "NCCL Error " << static_cast<int>(status) << ": "
@@ -308,9 +420,25 @@ AutoNcclGroup::AutoNcclGroup() {
#endif
}
AutoNcclGroup::AutoNcclGroup(
std::vector<ncclComm_t>& comms,
bool comm_nonblocking) {
(c10::cuda::getFreeMutex())->lock();
// TODO(eqy): can we make comms_ a reference?
comms_ = comms;
comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupStart());
#endif
}
AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupEnd());
if (!comm_nonblocking_) {
detail::NCCL_CHECK(ncclGroupEnd());
} else {
detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
}
#endif
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
(c10::cuda::getFreeMutex())->unlock();
@@ -677,7 +805,11 @@ void all2all_single_equal_split(
ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
@@ -730,7 +862,11 @@ void all2all_single_unequal_split(
stream));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -773,7 +909,11 @@ void all2all(
stream.stream()));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -791,6 +931,7 @@ void send(
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
@@ -798,6 +939,17 @@ void send(
dst,
to_nccl_comm(comm),
stream.stream()));
#else
NCCL_CHECK_TIMEOUT(
ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
to_nccl_comm(comm),
stream.stream()),
comm);
#endif
#else
AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -815,6 +967,7 @@ void recv(
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
@@ -822,6 +975,17 @@ void recv(
src,
to_nccl_comm(comm),
stream.stream()));
#else
NCCL_CHECK_TIMEOUT(
ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
to_nccl_comm(comm),
stream.stream()),
comm);
#endif
#else
AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -865,7 +1029,11 @@ void gather(
} else {
NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
@@ -888,9 +1056,13 @@ void scatter(
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
#else
NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
#endif
NCCL_CHECK(ncclGroupStart());
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
@@ -910,8 +1082,11 @@ void scatter(
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif


@@ -46,7 +46,8 @@ enum class ncclResult {
InternalError = 3,
InvalidArgument = 4,
InvalidUsage = 5,
NumResults = 6
NumResults = 6,
InProgress = 7
};
/* Reduction operation selector */
@@ -77,7 +78,10 @@ enum class ncclDataType {
// manages group and lock lifetimes.
struct AutoNcclGroup {
AutoNcclGroup();
AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
~AutoNcclGroup() noexcept(false);
std::vector<ncclComm_t> comms_;
bool comm_nonblocking_;
};
// NOTE: this is exposed only so that python_nccl.cpp can use some of these helpers.


@@ -1,6 +1,7 @@
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#ifdef USE_C10D_NCCL
@@ -52,6 +53,35 @@ std::string getNcclVersion() {
return versionString;
}
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN("Using experimental non-blocking NCCL communicator.");
}
return nccl_use_nonblocking_;
}
int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
}
}
return timeout;
}
int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
std::string ncclGetErrorWithVersion(ncclResult_t error) {
return std::string(ncclGetErrorString(error)) + ", NCCL version " +
getNcclVersion();


@@ -12,6 +12,11 @@
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 14)
#define NCCL_HAS_COMM_NONBLOCKING
#endif
// ncclGetLastError() is enabled only for NCCL versions 2.13+
// ncclRemoteError only exists in NCCL versions 2.13+
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -59,6 +64,60 @@
} \
} while (0)
// Macro to throw on a non-successful NCCL return value, non-blocking.
#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
ncclResult_t result = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
while (result == ncclInProgress) { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} \
ncclCommGetAsyncError(comm, &result); \
} \
if (result != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
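// Group-end variant: when ncclGroupEnd() reports ncclInProgress, poll each
// communicator in the group until it leaves the in-progress state (or the
// optional timeout expires), then surface any failure.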
#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
ncclResult_t state = cmd; \
auto startTimepoint = std::chrono::steady_clock::now(); \
if (state == ncclInProgress) { \
for (const auto i : c10::irange(comms_.size())) { \
do { \
if (nccl_nonblocking_timeout() > 0) { \
auto currentTimepoint = std::chrono::steady_clock::now(); \
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
if (timeElapsed > nccl_nonblocking_timeout()) { \
std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} \
ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
} while (state == ncclInProgress); \
if (state != ncclSuccess) { \
break; /* fall through to failed case */ \
} \
} \
} \
if (state != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
"\n" + getNcclErrorDetailStr(state, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
}
// Macro to print and abort on a non-successful NCCL return value.
#define C10D_NCCL_ASSERT(cmd) \
do { \
@@ -79,6 +138,8 @@ namespace c10d {
std::string getNcclVersion();
std::string ncclGetErrorWithVersion(ncclResult_t error);
bool nccl_use_nonblocking();
int nccl_nonblocking_timeout();
// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
@@ -118,8 +179,17 @@ class NCCLComm {
int rank,
ncclUniqueId commId) {
auto comm = std::make_shared<NCCLComm>();
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
#else
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
if (nccl_use_nonblocking()) {
config.blocking = 0;
}
C10D_NCCL_CHECK_TIMEOUT(
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
#endif
comm->ncclId_ = commId;
comm->rank_ = rank;
return comm;
@@ -165,8 +235,12 @@ class NCCLComm {
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
// timeout)
commFailureReason_ = commFailureReason;
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
C10D_NCCL_CHECK_TIMEOUT(
::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
#endif
aborted_ = true;
ncclComm_ = nullptr;


@@ -1120,6 +1120,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
// the following for loop.
for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
(void)i;
// comms have not been initialized yet, so we can only check in a blocking way
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
}
@@ -1156,7 +1157,15 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
}
// [Note 2 ]
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#else
if (!nccl_use_nonblocking()) {
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
} else {
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms, c10::nullopt);
}
#endif
// At this point NCCL should have been initialized, hence we can accurately
// get the env value even if NCCL sets it by reading from nccl.conf file
@@ -1387,7 +1396,19 @@ void ProcessGroupNCCL::startCoalescing() {
void ProcessGroupNCCL::endCoalescing(
std::vector<c10::intrusive_ptr<Work>>& reqs) {
groupEnd();
if (!nccl_use_nonblocking()) {
groupEnd();
} else {
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
for (const auto& req : reqs) {
auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
ncclComms_.insert(
ncclComms_.end(),
ncclWork->ncclComms_.begin(),
ncclWork->ncclComms_.end());
}
groupEndNonblocking(ncclComms_);
}
if (reqs.size() != coalescedDevices_.size()) {
TORCH_CHECK(false, "Number of requests does not match number of collectives");
}
@@ -1478,8 +1499,17 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
pre(ncclStreams, work);
std::vector<void*> comms_;
if (nccl_use_nonblocking()) {
for (const auto i : c10::irange(inputs.size())) {
decltype(i) stream_comm_i = (inputs_same_dev ? 0 : i);
comms_.push_back((void*)ncclComms[stream_comm_i]->getNcclComm());
}
}
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard;
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
comms_, nccl_use_nonblocking());
for (const auto i : c10::irange(inputs.size())) {
if (!inputs_same_dev || (inputs_same_dev && i == 0)) {
gpuGuard.set_index(devices[i].index());
@@ -1499,12 +1529,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].storage().data_ptr(), ncclStream);
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream),
ncclComm->getNcclCommFailureReason());
#else
C10D_NCCL_CHECK_TIMEOUT(
fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream),
ncclComm->getNcclComm(),
ncclComm->getNcclCommFailureReason());
#endif
}
}
post(ncclStreams);
// End event should only be recorded after the ncclGroupEnd()
@@ -1634,17 +1670,34 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
tensors[i].storage().data_ptr(), ncclStream);
}
std::vector<void*> comms_;
if (nccl_use_nonblocking()) {
for (const auto i : c10::irange(tensors.size())) {
comms_.push_back((void*)ncclComms[i]->getNcclComm());
}
}
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard;
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
comms_, nccl_use_nonblocking());
for (const auto i : c10::irange(tensors.size())) {
gpuGuard.set_index(devices[i].index());
at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
fn(tensors[i],
ncclComms[i]->getNcclComm(),
ncclStream,
p2pTargetRank),
ncclComms[i]->getNcclCommFailureReason());
#else
C10D_NCCL_CHECK_TIMEOUT(
fn(tensors[i],
ncclComms[i]->getNcclComm(),
ncclStream,
p2pTargetRank),
ncclComms[i]->getNcclComm(),
ncclComms[i]->getNcclCommFailureReason());
#endif
}
}
@@ -2610,7 +2663,38 @@ void ProcessGroupNCCL::groupStart() {
void ProcessGroupNCCL::groupEnd() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#else
if (!nccl_use_nonblocking()) {
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
} else {
TORCH_WARN(
"ProcessGroupNCCL::groupEnd() called in nonblocking communicator mode without involved communicators specified; gathering all mapped communicators...");
std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
for (auto& it : devNCCLCommMap_) {
ncclComms_.insert(ncclComms_.end(), it.second.begin(), it.second.end());
}
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms_, c10::nullopt);
}
#endif
#endif
--ncclActiveGroupCounter_;
}
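// Nonblocking variant of groupEnd(): callers pass the communicators that
// participated in the group so completion can be polled per communicator.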
void ProcessGroupNCCL::groupEndNonblocking(
std::vector<std::shared_ptr<NCCLComm>> comms) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#else
if (!nccl_use_nonblocking()) {
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
} else {
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comms, c10::nullopt);
}
#endif
#endif
--ncclActiveGroupCounter_;
}


@@ -404,9 +404,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int srcRank,
int tag) override;
static void groupStart();
void groupStart();
static void groupEnd();
void groupEnd();
void groupEndNonblocking(std::vector<std::shared_ptr<NCCLComm>> comms);
// Unsupported Ops
c10::intrusive_ptr<Work> gather(


@@ -1970,6 +1970,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
},
py::arg("abort_reason") = py::none(),
py::call_guard<py::gil_scoped_release>())
.def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
.def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
.def_property_readonly(
"options", &::c10d::ProcessGroupNCCL::getOptions)
.def_property_readonly(
@@ -1999,10 +2001,6 @@ Example::
.def_readwrite(
"is_high_priority_stream",
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
processGroupNCCL.def_static(
"_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); });
processGroupNCCL.def_static(
"_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
#endif
#ifdef USE_C10D_MPI