mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60543 Since now c10d is part of libtorch, it would also be nice if the sources lived all in one place. ghstack-source-id: 132306292 Test Plan: It builds Reviewed By: cbalioglu Differential Revision: D29062002 fbshipit-source-id: d9e1301e9d73e1643fa0f0119cd2d618f1ad52e6
208 lines
6.7 KiB
C++
208 lines
6.7 KiB
C++
#pragma once
|
|
|
|
#ifdef USE_C10D_NCCL
|
|
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
|
|
#include <memory>
|
|
#include <mutex>
|
|
|
|
#include <nccl.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace {
|
|
// Provides additional detail into NCCL error codes based on when these are
|
|
// thrown in the NCCL codebase.
|
|
const inline char* getNcclErrorDetailStr(ncclResult_t error) {
|
|
switch (error) {
|
|
case ncclUnhandledCudaError:
|
|
return "ncclUnhandledCudaError: Call to CUDA function failed.";
|
|
case ncclSystemError:
|
|
return "ncclSystemError: System call (socket, malloc, munmap, etc) failed.";
|
|
case ncclInternalError:
|
|
return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption";
|
|
case ncclInvalidArgument:
|
|
return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc).";
|
|
case ncclInvalidUsage:
|
|
return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).";
|
|
default:
|
|
break;
|
|
}
|
|
return "Unknown NCCL error";
|
|
}
|
|
} // namespace
|
|
// ncclCommAbort() and ncclCommGetAsyncError() only exist from NCCL 2.4
// onwards, so error checking is switched on for 2.x releases with minor >= 4
// and for every major version from 3 up.
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif
|
|
|
|
// ncclSend() and ncclRecv() only exist from NCCL 2.7 onwards, so P2P support
// is switched on for 2.x releases with minor >= 7 and for every major version
// from 3 up.
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif
|
|
|
|
// Macro to throw on a non-successful NCCL return value.
//
// `cmd` is evaluated exactly once. When it does not yield ncclSuccess, the
// macro raises through TORCH_CHECK with a message made of the file and line
// of the call site, the versioned NCCL error string and the detail text from
// getNcclErrorDetailStr().
//
// Hygiene notes:
//  - The macro-local variables use a distinctive name so that an expression
//    such as `C10D_NCCL_CHECK(foo(&result))` keeps referring to the caller's
//    `result` rather than to the macro's own (still uninitialized) local.
//  - `cmd` is parenthesized so that any expression can be passed safely.
//  - ncclGetErrorWithVersion is fully qualified so the macro also compiles
//    when it is expanded outside of namespace c10d.
#define C10D_NCCL_CHECK(cmd)                                              \
  do {                                                                    \
    const ncclResult_t c10dNcclCheckResult_ = (cmd);                      \
    if (c10dNcclCheckResult_ != ncclSuccess) {                            \
      const std::string c10dNcclCheckErr_ = "NCCL error in: " +           \
          std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \
          ::c10d::ncclGetErrorWithVersion(c10dNcclCheckResult_) + "\n" +  \
          getNcclErrorDetailStr(c10dNcclCheckResult_);                    \
      TORCH_CHECK(false, c10dNcclCheckErr_);                              \
    }                                                                     \
  } while (0)
|
|
|
|
// Macro to print and abort on a non-successful NCCL return value.
//
// `cmd` is evaluated exactly once. When it does not yield ncclSuccess, the
// file, line and versioned NCCL error string are written to stderr and the
// process is terminated with abort(). Unlike C10D_NCCL_CHECK this never
// throws, which makes it usable from noexcept contexts such as destructors.
//
// Hygiene notes:
//  - The macro-local variables use a distinctive name so they cannot shadow
//    (or be captured by) identifiers appearing inside `cmd`.
//  - `cmd` is parenthesized so that any expression can be passed safely.
//  - ncclGetErrorWithVersion is fully qualified so the macro also compiles
//    when it is expanded outside of namespace c10d.
#define C10D_NCCL_ASSERT(cmd)                                    \
  do {                                                           \
    const ncclResult_t c10dNcclAssertResult_ = (cmd);            \
    if (c10dNcclAssertResult_ != ncclSuccess) {                  \
      const std::string c10dNcclAssertErr_ =                     \
          ::c10d::ncclGetErrorWithVersion(c10dNcclAssertResult_); \
      fprintf(                                                   \
          stderr,                                                \
          "NCCL error in: %s:%d, %s\n",                          \
          __FILE__,                                              \
          __LINE__,                                              \
          c10dNcclAssertErr_.c_str());                           \
      abort();                                                   \
    }                                                            \
  } while (0)
|
|
|
|
namespace c10d {
|
|
|
|
std::string getNcclVersion();
|
|
std::string ncclGetErrorWithVersion(ncclResult_t error);
|
|
|
|
// RAII wrapper for NCCL communicator
|
|
class NCCLComm {
|
|
public:
|
|
explicit NCCLComm(ncclComm_t ncclComm)
|
|
: ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}
|
|
|
|
NCCLComm() : NCCLComm(nullptr) {}
|
|
|
|
~NCCLComm() noexcept {
|
|
// Add lock in this destructor, as aborted_ needs to be read after memory
|
|
// barrier here.
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (ncclComm_ && !aborted_) {
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
// Use ncclCommAbort instead of ncclCommDestroy here since
|
|
// ncclCommDestroy could block forever waiting for work to complete on
|
|
// the communicator.
|
|
C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
|
|
#else
|
|
C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
|
|
#endif
|
|
}
|
|
}
|
|
|
|
static std::shared_ptr<NCCLComm> create(
|
|
int numRanks,
|
|
int rank,
|
|
ncclUniqueId commId) {
|
|
auto comm = std::make_shared<NCCLComm>();
|
|
C10D_NCCL_CHECK(
|
|
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
|
|
comm->ncclId_ = commId;
|
|
comm->rank_ = rank;
|
|
return comm;
|
|
}
|
|
|
|
ncclUniqueId getNcclId() {
|
|
return ncclId_;
|
|
}
|
|
|
|
// Must not be copyable
|
|
NCCLComm(const NCCLComm&) = delete;
|
|
NCCLComm& operator=(const NCCLComm&) = delete;
|
|
|
|
// Do not support move assignment as there is no valid use case
|
|
NCCLComm& operator=(NCCLComm&& other) = delete;
|
|
|
|
// Move constructable
|
|
NCCLComm(NCCLComm&& other) {
|
|
// Using other's lock, as it reads other's states
|
|
// Can not use this.mutex_, as this object is being constructed.
|
|
std::unique_lock<std::mutex> lock(other.mutex_);
|
|
std::swap(ncclComm_, other.ncclComm_);
|
|
std::swap(aborted_, other.aborted_);
|
|
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
|
}
|
|
|
|
ncclComm_t getNcclComm() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (aborted_) {
|
|
TORCH_CHECK(false,
|
|
"NCCL communicator was aborted on rank " + std::to_string(rank_) +
|
|
".");
|
|
}
|
|
return ncclComm_;
|
|
}
|
|
|
|
void ncclCommAbort() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
if (aborted_) {
|
|
// Should not abort twice.
|
|
return;
|
|
}
|
|
|
|
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
|
|
aborted_ = true;
|
|
ncclComm_ = nullptr;
|
|
|
|
// Set an appropriate error so that we avoid using the communicator.
|
|
if (ncclAsyncErr_ == ncclSuccess) {
|
|
ncclAsyncErr_ = ncclSystemError;
|
|
}
|
|
#else
|
|
// This is a NOOP, if error checks are disabled.
|
|
return;
|
|
#endif
|
|
}
|
|
|
|
bool isAborted() const {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
return aborted_;
|
|
}
|
|
|
|
ncclResult_t checkForNcclError() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
if (ncclAsyncErr_ != ncclSuccess) {
|
|
return ncclAsyncErr_;
|
|
}
|
|
C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
|
|
return ncclAsyncErr_;
|
|
#else
|
|
// Always return success, if error checks are disabled.
|
|
return ncclSuccess;
|
|
#endif
|
|
}
|
|
|
|
protected:
|
|
ncclComm_t ncclComm_;
|
|
// Unique nccl_id for this communicator.
|
|
ncclUniqueId ncclId_;
|
|
bool aborted_;
|
|
ncclResult_t ncclAsyncErr_;
|
|
mutable std::mutex mutex_;
|
|
// Rank that this communicator corresponds to.
|
|
int rank_;
|
|
};
|
|
|
|
} // namespace c10d
|
|
|
|
#endif // USE_C10D_NCCL
|