[NCCL] Add experimental Nonblocking NCCL Fault Tolerance/Checking (#95715)
Adds support for nonblocking NCCL communicators / fault tolerance / checking, introduced in NCCL 2.14 as an experimental feature. Enabled via the environment variable:

```
TORCH_NCCL_USE_COMM_NONBLOCKING=1
```

CC @ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95715
Approved by: https://github.com/kwen2501
Commit a33eac3988 (parent 09458a2bf1), committed by PyTorch MergeBot.
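For context before the diff: with `TORCH_NCCL_USE_COMM_NONBLOCKING=1`, communicators are created with `config.blocking = 0`, NCCL calls may return `ncclInProgress` instead of completing synchronously, and completion is detected by polling `ncclCommGetAsyncError`, optionally bounded by `TORCH_NCCL_NONBLOCKING_TIMEOUT` (seconds). A minimal standalone sketch of that raw NCCL >= 2.14 pattern, independent of the PyTorch wrappers in this diff (assumes a single process on GPU 0; error handling is reduced to the polling loop):

```cpp
// A minimal sketch of the NCCL >= 2.14 nonblocking pattern this PR wraps.
#include <cuda_runtime.h>
#include <nccl.h>

#include <chrono>
#include <stdexcept>

// Poll until the communicator's pending operation finishes, giving up after
// timeout_s seconds (the role TORCH_NCCL_NONBLOCKING_TIMEOUT plays here).
static void waitForComm(ncclComm_t comm, int timeout_s) {
  ncclResult_t state = ncclInProgress;
  const auto start = std::chrono::steady_clock::now();
  while (state == ncclInProgress) {
    ncclCommGetAsyncError(comm, &state);
    const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                             std::chrono::steady_clock::now() - start)
                             .count();
    if (timeout_s > 0 && elapsed > timeout_s) {
      throw std::runtime_error("NCCL timeout.");
    }
  }
  if (state != ncclSuccess) {
    throw std::runtime_error(ncclGetErrorString(state));
  }
}

int main() {
  cudaSetDevice(0);
  ncclUniqueId id;
  ncclGetUniqueId(&id);

  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0;  // request a nonblocking communicator

  // With blocking == 0 this may return ncclInProgress rather than blocking.
  ncclComm_t comm;
  ncclCommInitRankConfig(&comm, /*nranks=*/1, id, /*rank=*/0, &config);
  waitForComm(comm, /*timeout_s=*/30);

  ncclCommDestroy(comm);
  return 0;
}
```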
torch/_C/_distributed_c10d.pyi

```diff
@@ -408,10 +408,8 @@ class ProcessGroupNCCL(ProcessGroup):
         size: int,
         timeout: timedelta,
     ): ...
-    @staticmethod
-    def _group_start() -> None: ...
-    @staticmethod
-    def _group_end() -> None: ...
+    def _group_start(self) -> None: ...
+    def _group_end(self) -> None: ...

 class ProcessGroupUCC(ProcessGroup):
     def __init__(
```
torch/csrc/cuda/nccl.cpp

```diff
@@ -16,6 +16,11 @@
 #include <type_traits>
 #include <unordered_map>

+#if !defined(USE_ROCM) && \
+    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
+#define NCCL_HAS_COMM_NONBLOCKING 1
+#endif
+
 ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
   return reinterpret_cast<ncclComm_t*>(var);
 }
```
```diff
@@ -44,6 +49,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
       return ncclResult_t::ncclInvalidUsage;
     case torch::cuda::nccl::ncclResult::NumResults:
       return ncclResult_t::ncclNumResults;
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+    case torch::cuda::nccl::ncclResult::InProgress:
+      return ncclResult_t::ncclInProgress;
+#endif
     default:
       throw std::runtime_error("Unconvertible NCCL type");
   }
```
```diff
@@ -65,6 +74,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
       return torch::cuda::nccl::ncclResult::InvalidUsage;
     case ncclNumResults:
       return torch::cuda::nccl::ncclResult::NumResults;
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+    case ncclInProgress:
+      return torch::cuda::nccl::ncclResult::InProgress;
+#endif
     default:
       throw std::runtime_error("Unconvertible NCCL type");
   }
```
```diff
@@ -123,6 +136,105 @@ static inline void NCCL_CHECK(ncclResult_t result) {
   NCCL_CHECK(from_nccl_result(result));
 }

+// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
+bool nccl_use_nonblocking() {
+  static bool nccl_use_nonblocking_ =
+      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
+  if (nccl_use_nonblocking_) {
+    TORCH_WARN("Using experimental non-blocking NCCL communicator.");
+  }
+  return nccl_use_nonblocking_;
+}
+
+static int _parse_nccl_nonblocking_timeout() {
+  const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
+  int timeout = -1;
+  if (val) {
+    const std::string config(val);
+    timeout = std::stoi(config);
+    if (!nccl_use_nonblocking() && timeout > 0) {
+      TORCH_WARN(
+          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
+      timeout = -1;
+    }
+  }
+  return timeout;
+}
+
+static int nccl_nonblocking_timeout() {
+  static int timeout = _parse_nccl_nonblocking_timeout();
+  return timeout;
+}
+
+static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+  ncclResult_t result = to_nccl_result(status);
+  auto startTimepoint = std::chrono::steady_clock::now();
+  while (result == ncclInProgress) {
+    if (nccl_nonblocking_timeout() > 0) {
+      auto currentTimepoint = std::chrono::steady_clock::now();
+      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
+                             currentTimepoint - startTimepoint)
+                             .count();
+      if (timeElapsed > nccl_nonblocking_timeout()) {
+        throw std::runtime_error("NCCL timeout.");
+      }
+    }
+    ncclCommGetAsyncError(to_nccl_comm(comm), &result);
+  }
+  if (result != ncclSuccess) {
+    throw_nccl_error(from_nccl_result(result));
+  }
+#else
+  TORCH_INTERNAL_ASSERT(
+      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
+#endif
+}
+
+static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
+  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
+}
+
+static inline void NCCL_CHECK_TIMEOUT(
+    ncclResult status,
+    std::vector<ncclComm_t>& comms) {
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+  ncclResult_t result = to_nccl_result(status);
+  auto startTimepoint = std::chrono::steady_clock::now();
+  if (result == ncclInProgress) {
+    for (const auto i : c10::irange(comms.size())) {
+      do {
+        if (nccl_nonblocking_timeout() > 0) {
+          auto currentTimepoint = std::chrono::steady_clock::now();
+          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
+                                 currentTimepoint - startTimepoint)
+                                 .count();
+          if (timeElapsed > nccl_nonblocking_timeout()) {
+            throw std::runtime_error("NCCL timeout.");
+          }
+        }
+        ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
+      } while (result == ncclInProgress);
+      if (result != ncclSuccess) {
+        break; /* fall through to failed case */
+      }
+    }
+  }
+  if (result != ncclSuccess) {
+    throw_nccl_error(from_nccl_result(result));
+  }
+#else
+  TORCH_INTERNAL_ASSERT(
+      false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
+#endif
+}
+
+static inline void NCCL_CHECK_TIMEOUT(
+    ncclResult_t result,
+    std::vector<ncclComm_t>& comms) {
+  NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
+}
+
 void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
   std::ostringstream err;
   err << "NCCL Error " << static_cast<int>(status) << ": "
```
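One sharp edge in `_parse_nccl_nonblocking_timeout` above: `std::stoi` throws on a non-numeric `TORCH_NCCL_NONBLOCKING_TIMEOUT`. A hedged sketch of a more defensive variant (`parse_timeout_env` is an illustrative name, not part of this diff):

```cpp
// Illustrative only: like _parse_nccl_nonblocking_timeout above, but treating
// a malformed value as "unset" instead of letting std::stoi throw.
#include <cstdlib>
#include <string>

static int parse_timeout_env(const char* name) {
  const char* val = std::getenv(name);
  if (val == nullptr) {
    return -1;  // unset: no timeout
  }
  try {
    return std::stoi(val);
  } catch (const std::exception&) {
    return -1;  // non-numeric or out-of-range value: treat as unset
  }
}
```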
```diff
@@ -308,9 +420,25 @@ AutoNcclGroup::AutoNcclGroup() {
 #endif
 }

+AutoNcclGroup::AutoNcclGroup(
+    std::vector<ncclComm_t>& comms,
+    bool comm_nonblocking) {
+  (c10::cuda::getFreeMutex())->lock();
+  // TODO(eqy): can we make comms_ reference?
+  comms_ = comms;
+  comm_nonblocking_ = comm_nonblocking;
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
+  detail::NCCL_CHECK(ncclGroupStart());
+#endif
+}
+
 AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-  detail::NCCL_CHECK(ncclGroupEnd());
+  if (!comm_nonblocking_) {
+    detail::NCCL_CHECK(ncclGroupEnd());
+  } else {
+    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
+  }
 #endif
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
   (c10::cuda::getFreeMutex())->unlock();
```
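The `AutoNcclGroup` destructor change above is the crux of this file: with blocking communicators `ncclGroupEnd()` itself waits, while nonblocking communicators can report `ncclInProgress` and must each be polled to completion. A reduced sketch of that RAII shape (timeout handling and the mutex omitted; not the actual `AutoNcclGroup`):

```cpp
// Reduced sketch of the AutoNcclGroup idea: group-start on construction,
// group-end on destruction, polling each communicator when nonblocking.
#include <nccl.h>

#include <utility>
#include <vector>

struct GroupGuard {
  std::vector<ncclComm_t> comms_;
  bool nonblocking_;

  GroupGuard(std::vector<ncclComm_t> comms, bool nonblocking)
      : comms_(std::move(comms)), nonblocking_(nonblocking) {
    ncclGroupStart();
  }

  ~GroupGuard() noexcept(false) {
    ncclResult_t state = ncclGroupEnd();
    if (!nonblocking_) {
      return;  // blocking communicators: ncclGroupEnd() already completed
    }
    // Nonblocking communicators: ncclGroupEnd() may report ncclInProgress,
    // so poll each communicator until it leaves the in-progress state.
    for (ncclComm_t comm : comms_) {
      do {
        ncclCommGetAsyncError(comm, &state);
      } while (state == ncclInProgress);  // timeout check elided
    }
  }
};
```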
```diff
@@ -677,7 +805,11 @@ void all2all_single_equal_split(
         ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
     }
   }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclGroupEnd());
+#else
+  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
+#endif
 #endif
 #else
   AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
```
```diff
@@ -730,7 +862,11 @@ void all2all_single_unequal_split(
         stream));
     }
   }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclGroupEnd());
+#else
+  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
+#endif
 #else
   AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
```
```diff
@@ -773,7 +909,11 @@ void all2all(
           stream.stream()));
     }
   }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclGroupEnd());
+#else
+  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
+#endif
 #else
   AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
```
```diff
@@ -791,6 +931,7 @@ void send(
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
     (NCCL_MINOR >= 7)
   using namespace torch::cuda::nccl::detail;
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclSend(
       input.data_ptr(),
       input.numel(),
```
```diff
@@ -798,6 +939,17 @@ void send(
       dst,
       to_nccl_comm(comm),
       stream.stream()));
+#else
+  NCCL_CHECK_TIMEOUT(
+      ncclSend(
+          input.data_ptr(),
+          input.numel(),
+          to_nccl_data_type(input),
+          dst,
+          to_nccl_comm(comm),
+          stream.stream()),
+      comm);
+#endif
 #else
   AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
 #endif
```
```diff
@@ -815,6 +967,7 @@ void recv(
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
     (NCCL_MINOR >= 7)
   using namespace torch::cuda::nccl::detail;
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclRecv(
       output.data_ptr(),
       output.numel(),
```
```diff
@@ -822,6 +975,17 @@ void recv(
       src,
       to_nccl_comm(comm),
       stream.stream()));
+#else
+  NCCL_CHECK_TIMEOUT(
+      ncclRecv(
+          output.data_ptr(),
+          output.numel(),
+          to_nccl_data_type(output),
+          src,
+          to_nccl_comm(comm),
+          stream.stream()),
+      comm);
+#endif
 #else
   AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
 #endif
```
```diff
@@ -865,7 +1029,11 @@ void gather(
   } else {
     NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
   }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclGroupEnd());
+#else
+  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
+#endif

 #else
   AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
```
```diff
@@ -888,9 +1056,13 @@ void scatter(

   auto comm = to_nccl_comm(_comm);
   int numranks, cur_rank;
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclCommCount(comm, &numranks));
   NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
-
+#else
+  NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
+  NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
+#endif
   NCCL_CHECK(ncclGroupStart());
   if (cur_rank == root) {
     for (const auto r : c10::irange(numranks)) {
```
```diff
@@ -910,8 +1082,11 @@ void scatter(
     auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
     NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
   }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   NCCL_CHECK(ncclGroupEnd());
-
+#else
+  NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
+#endif
 #else
   AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
 #endif
```
torch/csrc/cuda/nccl.h

```diff
@@ -46,7 +46,8 @@ enum class ncclResult {
   InternalError = 3,
   InvalidArgument = 4,
   InvalidUsage = 5,
-  NumResults = 6
+  NumResults = 6,
+  InProgress = 7
 };

 /* Reduction operation selector */
```
```diff
@@ -77,7 +78,10 @@ enum class ncclDataType {
 // manages group and lock lifetimes.
 struct AutoNcclGroup {
   AutoNcclGroup();
+  AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
   ~AutoNcclGroup() noexcept(false);
+  std::vector<ncclComm_t> comms_;
+  bool comm_nonblocking_;
 };

 // NOTE: this is exposed only so that python_nccl.cpp can use some of these helpers.
```
torch/csrc/distributed/c10d/NCCLUtils.cpp

```diff
@@ -1,6 +1,7 @@
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

 #include <c10/util/CallOnce.h>
+#include <c10/util/env.h>

 #ifdef USE_C10D_NCCL

```
```diff
@@ -52,6 +53,35 @@ std::string getNcclVersion() {
   return versionString;
 }

+bool nccl_use_nonblocking() {
+  static bool nccl_use_nonblocking_ =
+      c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
+  if (nccl_use_nonblocking_) {
+    TORCH_WARN("Using experimental non-blocking NCCL communicator.");
+  }
+  return nccl_use_nonblocking_;
+}
+
+int _parse_nccl_nonblocking_timeout() {
+  const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
+  int timeout = -1;
+  if (val) {
+    const std::string config(val);
+    timeout = std::stoi(config);
+    if (!nccl_use_nonblocking() && timeout > 0) {
+      TORCH_WARN(
+          "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
+      timeout = -1;
+    }
+  }
+  return timeout;
+}
+
+int nccl_nonblocking_timeout() {
+  static int timeout = _parse_nccl_nonblocking_timeout();
+  return timeout;
+}
+
 std::string ncclGetErrorWithVersion(ncclResult_t error) {
   return std::string(ncclGetErrorString(error)) + ", NCCL version " +
       getNcclVersion();
```
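`nccl_use_nonblocking()` and `nccl_nonblocking_timeout()` both rely on a function-local static, so the environment is consulted exactly once per process. A minimal sketch of the pattern in isolation (`use_feature` and `MY_FEATURE` are made-up names):

```cpp
// Minimal sketch of the caching pattern above: a function-local static is
// initialized once (thread-safely since C++11), so getenv runs a single time
// no matter how many times the helper is called.
#include <cstdlib>

bool use_feature() {
  static const bool enabled = [] {
    const char* v = std::getenv("MY_FEATURE");  // hypothetical variable
    return v != nullptr && v[0] == '1';
  }();
  return enabled;
}
```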
torch/csrc/distributed/c10d/NCCLUtils.hpp

```diff
@@ -12,6 +12,11 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>

+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR >= 14)
+#define NCCL_HAS_COMM_NONBLOCKING
+#endif
+
 // ncclGetLastError() is enabled only for NCCL versions 2.13+
 // ncclRemoteError only exists in NCCL versions 2.13+
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
```
```diff
@@ -59,6 +64,60 @@
     }                                                                          \
   } while (0)

+// Macro to throw on a non-successful NCCL return value, non-blocking.
+#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason)                      \
+  ncclResult_t result = cmd;                                                   \
+  auto startTimepoint = std::chrono::steady_clock::now();                      \
+  while (result == ncclInProgress) {                                           \
+    if (nccl_nonblocking_timeout() > 0) {                                      \
+      auto currentTimepoint = std::chrono::steady_clock::now();                \
+      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
+      if (timeElapsed > nccl_nonblocking_timeout()) {                          \
+        std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" +  \
+            std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
+            "\n" + getNcclErrorDetailStr(result, failureReason);               \
+        TORCH_CHECK_WITH(DistBackendError, false, err);                        \
+      }                                                                        \
+    }                                                                          \
+    ncclCommGetAsyncError(comm, &result);                                      \
+  }                                                                            \
+  if (result != ncclSuccess) {                                                 \
+    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +        \
+        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) +    \
+        "\n" + getNcclErrorDetailStr(result, failureReason);                   \
+    TORCH_CHECK_WITH(DistBackendError, false, err);                            \
+  }
+
+#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason)           \
+  ncclResult_t state = cmd;                                                    \
+  auto startTimepoint = std::chrono::steady_clock::now();                      \
+  if (state == ncclInProgress) {                                               \
+    for (const auto i : c10::irange(comms_.size())) {                          \
+      do {                                                                     \
+        if (nccl_nonblocking_timeout() > 0) {                                  \
+          auto currentTimepoint = std::chrono::steady_clock::now();            \
+          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
+          if (timeElapsed > nccl_nonblocking_timeout()) {                      \
+            std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
+                std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
+                "\n" + getNcclErrorDetailStr(state, failureReason);            \
+            TORCH_CHECK_WITH(DistBackendError, false, err);                    \
+          }                                                                    \
+        }                                                                      \
+        ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state);               \
+      } while (state == ncclInProgress);                                       \
+      if (state != ncclSuccess) {                                              \
+        break; /* fall through to failed case */                               \
+      }                                                                        \
+    }                                                                          \
+  }                                                                            \
+  if (state != ncclSuccess) {                                                  \
+    std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +        \
+        std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) +     \
+        "\n" + getNcclErrorDetailStr(state, failureReason);                    \
+    TORCH_CHECK_WITH(DistBackendError, false, err);                            \
+  }
+
 // Macro to print and abort on a non-successful NCCL return value.
 #define C10D_NCCL_ASSERT(cmd)                            \
   do {                                                   \
```
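Note that, unlike `C10D_NCCL_CHECK` above, the two timeout macros are not wrapped in `do { ... } while (0)`: they declare `result`/`state` in the enclosing scope, so each can be expanded at most once per block and is unsafe in an unbraced if/else. A hedged sketch of the conventional wrapped form (`CHECK_ASYNC_DONE` is a made-up name, and it drops the timeout and the rich error strings):

```cpp
// Illustrative only: the same poll-until-done check wrapped in the usual
// do { ... } while (0) so it behaves as a single statement with its own scope.
#include <nccl.h>

#include <stdexcept>

#define CHECK_ASYNC_DONE(cmd, comm)                           \
  do {                                                        \
    ncclResult_t result_ = (cmd);                             \
    while (result_ == ncclInProgress) {                       \
      ncclCommGetAsyncError((comm), &result_);                \
    }                                                         \
    if (result_ != ncclSuccess) {                             \
      throw std::runtime_error(ncclGetErrorString(result_));  \
    }                                                         \
  } while (0)
```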
```diff
@@ -79,6 +138,8 @@ namespace c10d {

 std::string getNcclVersion();
 std::string ncclGetErrorWithVersion(ncclResult_t error);
+bool nccl_use_nonblocking();
+int nccl_nonblocking_timeout();

 // Provides additional detail into NCCL error codes based on when these are
 // thrown in the NCCL codebase.
```
```diff
@@ -118,8 +179,17 @@ class NCCLComm {
       int rank,
       ncclUniqueId commId) {
     auto comm = std::make_shared<NCCLComm>();
+#ifndef NCCL_HAS_COMM_NONBLOCKING
     C10D_NCCL_CHECK(
         ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
+#else
+    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
+    if (nccl_use_nonblocking()) {
+      config.blocking = 0;
+    }
+    C10D_NCCL_CHECK_TIMEOUT(
+        ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
+#endif
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
```
```diff
@@ -165,8 +235,12 @@ class NCCLComm {
     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
     // timeout)
     commFailureReason_ = commFailureReason;

+#ifndef NCCL_HAS_COMM_NONBLOCKING
     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
+#else
+    C10D_NCCL_CHECK_TIMEOUT(
+        ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
+#endif
     aborted_ = true;
     ncclComm_ = nullptr;
```
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

```diff
@@ -1120,6 +1120,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   // the following for loop.
   for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
     (void)i;
+    // comms have not been initiated yet, so can only check in blocking-way
     C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
   }

```
```diff
@@ -1156,7 +1157,15 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   }

   // [Note 2 ]
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+#else
+  if (!nccl_use_nonblocking()) {
+    C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+  } else {
+    C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms, c10::nullopt);
+  }
+#endif

   // At this point NCCL should have been initialized, hence we can accurately
   // get the env value even if NCCL sets it by reading from nccl.conf file
```
```diff
@@ -1387,7 +1396,19 @@ void ProcessGroupNCCL::startCoalescing() {

 void ProcessGroupNCCL::endCoalescing(
     std::vector<c10::intrusive_ptr<Work>>& reqs) {
-  groupEnd();
+  if (!nccl_use_nonblocking()) {
+    groupEnd();
+  } else {
+    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
+    for (const auto& req : reqs) {
+      auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
+      ncclComms_.insert(
+          ncclComms_.end(),
+          ncclWork->ncclComms_.begin(),
+          ncclWork->ncclComms_.end());
+    }
+    groupEndNonblocking(ncclComms_);
+  }
   if (reqs.size() != coalescedDevices_.size()) {
     TORCH_CHECK(false, "Number of requests do not match number of collectives");
   }
```
```diff
@@ -1478,8 +1499,17 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(

   pre(ncclStreams, work);

+  std::vector<void*> comms_;
+  if (nccl_use_nonblocking()) {
+    for (const auto i : c10::irange(inputs.size())) {
+      decltype(i) stream_comm_i = (inputs_same_dev ? 0 : i);
+      comms_.push_back((void*)ncclComms[stream_comm_i]->getNcclComm());
+    }
+  }
+
   {
-    torch::cuda::nccl::AutoNcclGroup nccl_group_guard;
+    torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
+        comms_, nccl_use_nonblocking());
     for (const auto i : c10::irange(inputs.size())) {
       if (!inputs_same_dev || (inputs_same_dev && i == 0)) {
         gpuGuard.set_index(devices[i].index());
```
```diff
@@ -1499,12 +1529,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
         c10::cuda::CUDACachingAllocator::recordStream(
             inputs[i].storage().data_ptr(), ncclStream);
       }
+#ifndef NCCL_HAS_COMM_NONBLOCKING
       C10D_NCCL_CHECK(
           fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream),
           ncclComm->getNcclCommFailureReason());
+#else
+      C10D_NCCL_CHECK_TIMEOUT(
+          fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream),
+          ncclComm->getNcclComm(),
+          ncclComm->getNcclCommFailureReason());
+#endif
     }
   }

   post(ncclStreams);

   // End event should only be recorded after the ncclGroupEnd()
```
```diff
@@ -1634,17 +1670,34 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
           tensors[i].storage().data_ptr(), ncclStream);
     }

+    std::vector<void*> comms_;
+    if (nccl_use_nonblocking()) {
+      for (const auto i : c10::irange(tensors.size())) {
+        comms_.push_back((void*)ncclComms[i]->getNcclComm());
+      }
+    }
     {
-      torch::cuda::nccl::AutoNcclGroup nccl_group_guard;
+      torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
+          comms_, nccl_use_nonblocking());
       for (const auto i : c10::irange(tensors.size())) {
         gpuGuard.set_index(devices[i].index());
         at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
+#ifndef NCCL_HAS_COMM_NONBLOCKING
         C10D_NCCL_CHECK(
             fn(tensors[i],
                ncclComms[i]->getNcclComm(),
                ncclStream,
                p2pTargetRank),
             ncclComms[i]->getNcclCommFailureReason());
+#else
+        C10D_NCCL_CHECK_TIMEOUT(
+            fn(tensors[i],
+               ncclComms[i]->getNcclComm(),
+               ncclStream,
+               p2pTargetRank),
+            ncclComms[i]->getNcclComm(),
+            ncclComms[i]->getNcclCommFailureReason());
+#endif
       }
     }

```
```diff
@@ -2610,7 +2663,38 @@ void ProcessGroupNCCL::groupStart() {

 void ProcessGroupNCCL::groupEnd() {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
+#ifndef NCCL_HAS_COMM_NONBLOCKING
   C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+#else
+  if (!nccl_use_nonblocking()) {
+    C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+  } else {
+    TORCH_WARN(
+        "ProcessGroupNCCL::groupEnd() called in nonblocking communicator mode without involved communicators specified; gathering all mapped communicators...");
+    std::unique_lock<std::mutex> lock(mutex_);
+    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
+    for (auto& it : devNCCLCommMap_) {
+      ncclComms_.insert(ncclComms_.end(), it.second.begin(), it.second.end());
+    }
+    C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms_, c10::nullopt);
+  }
+#endif
 #endif
   --ncclActiveGroupCounter_;
 }
+
+void ProcessGroupNCCL::groupEndNonblocking(
+    std::vector<std::shared_ptr<NCCLComm>> comms) {
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
+#ifndef NCCL_HAS_COMM_NONBLOCKING
+  C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+#else
+  if (!nccl_use_nonblocking()) {
+    C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
+  } else {
+    C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comms, c10::nullopt);
+  }
+#endif
+#endif
+  --ncclActiveGroupCounter_;
+}
```
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

```diff
@@ -404,9 +404,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       int srcRank,
       int tag) override;

-  static void groupStart();
+  void groupStart();

-  static void groupEnd();
+  void groupEnd();

+  void groupEndNonblocking(std::vector<std::shared_ptr<NCCLComm>> comms);
+
   // Unsupported Ops
   c10::intrusive_ptr<Work> gather(
```
torch/csrc/distributed/c10d/init.cpp

```diff
@@ -1970,6 +1970,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           },
           py::arg("abort_reason") = py::none(),
           py::call_guard<py::gil_scoped_release>())
+      .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
+      .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
       .def_property_readonly(
           "options", &::c10d::ProcessGroupNCCL::getOptions)
       .def_property_readonly(
```
```diff
@@ -1999,10 +2001,6 @@ Example::
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
-  processGroupNCCL.def_static(
-      "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); });
-  processGroupNCCL.def_static(
-      "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); });
 #endif

 #ifdef USE_C10D_MPI
```