mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR adds support for NCCL's `ncclRedOpCreatePreMulSum`: https://docs.nvidia.com/deeplearning/nccl/archives/nccl_21212/user-guide/docs/api/ops.html?highlight=premul#c.ncclRedOpCreatePreMulSum
The major changes are:
- convert enum ReduceOp to struct
- add premul sum specific paths to init.cpp and Ops.cpp.
note:
- For pip wheels / conda binaries to support this, ~~I think https://github.com/pytorch/pytorch/pull/79132 would be needed~~ https://github.com/pytorch/pytorch/pull/82775 landed
The commit titled "add nccl premul" whose current hash is cb99ad6744
was authored by @mcarilli and @ptrblck.
cc @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81272
Approved by: https://github.com/kwen2501
256 lines
8.7 KiB
C++
256 lines
8.7 KiB
C++
#pragma once
|
|
|
|
#ifdef USE_C10D_NCCL
|
|
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
|
|
#include <memory>
|
|
#include <mutex>
|
|
|
|
#include <nccl.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
namespace {
|
|
// Provides additional detail into NCCL error codes based on when these are
|
|
// thrown in the NCCL codebase.
|
|
const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::string> processGroupFailureReason = c10::nullopt) {
|
|
// Prioritize failure reason provided by PG NCCL first, as it can abort
|
|
// communicators when it encounters collective timeouts, etc.
|
|
if (processGroupFailureReason != c10::nullopt) {
|
|
return (*processGroupFailureReason).c_str();
|
|
}
|
|
switch (error) {
|
|
case ncclUnhandledCudaError:
|
|
return "ncclUnhandledCudaError: Call to CUDA function failed.";
|
|
case ncclSystemError:
|
|
return "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. "
|
|
"It can be also caused by unexpected exit of a remote peer, you can check NCCL warnings for failure reason and see if there is connection closure by a peer.";
|
|
case ncclInternalError:
|
|
return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption";
|
|
case ncclInvalidArgument:
|
|
return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc).";
|
|
case ncclInvalidUsage:
|
|
return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).";
|
|
default:
|
|
break;
|
|
}
|
|
return "Unknown NCCL error";
|
|
}
|
|
} // namespace
|
|
// ncclCommAbort() and ncclCommGetAsyncError() are not supported before
// NCCL 2.4, so error checking is compiled in only for 2.x releases with
// minor >= 4 and for any release with major >= 3.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif
|
|
|
|
// ncclSend() and ncclRecv() are not supported before NCCL 2.7, so P2P is
// compiled in only for 2.x releases with minor >= 7 and for any release with
// major >= 3.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif
|
|
|
|
// Premul-sum reduction ops are compiled in only for NCCL 2.x releases with
// minor >= 11 and for any release with major >= 3.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 11)))
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif
|
|
|
|
// Macro to throw on a non-successful NCCL return value.
//
//   cmd           - expression yielding an ncclResult_t; evaluated exactly
//                   once.
//   failureReason - forwarded as the second argument of
//                   getNcclErrorDetailStr(); its text, when present, replaces
//                   the generic per-code description.
//
// On failure the message is "NCCL error in: <file>:<line>, <error with
// version>\n<detail>" and is reported through TORCH_CHECK(false, ...).
//
// The locals use macro-specific names so that identifiers such as `result` or
// `err` appearing inside `cmd` / `failureReason` keep referring to the
// caller's variables instead of being shadowed by the macro's own. `cmd` is
// parenthesized so any expression can be passed safely.
#define C10D_NCCL_CHECK(cmd, failureReason) \
  do { \
    const ncclResult_t c10d_nccl_check_result_ = (cmd); \
    if (c10d_nccl_check_result_ != ncclSuccess) { \
      std::string c10d_nccl_check_err_ = "NCCL error in: " + \
          std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \
          ncclGetErrorWithVersion(c10d_nccl_check_result_) + "\n" + \
          getNcclErrorDetailStr(c10d_nccl_check_result_, failureReason); \
      TORCH_CHECK(false, c10d_nccl_check_err_); \
    } \
  } while (0)
|
|
|
|
// Macro to print and abort on a non-successful NCCL return value.
//
//   cmd - expression yielding an ncclResult_t; evaluated exactly once.
//
// On failure writes "NCCL error in: <file>:<line>, <error with version>" to
// stderr and calls abort(). Nothing is thrown, which is why the NCCLComm
// destructor uses this macro rather than C10D_NCCL_CHECK.
//
// The locals use macro-specific names so that identifiers such as `result` or
// `err` appearing inside `cmd` keep referring to the caller's variables
// instead of being shadowed by the macro's own. `cmd` is parenthesized so any
// expression can be passed safely.
#define C10D_NCCL_ASSERT(cmd) \
  do { \
    const ncclResult_t c10d_nccl_assert_result_ = (cmd); \
    if (c10d_nccl_assert_result_ != ncclSuccess) { \
      const std::string c10d_nccl_assert_err_ = \
          ncclGetErrorWithVersion(c10d_nccl_assert_result_); \
      fprintf( \
          stderr, \
          "NCCL error in: %s:%d, %s\n", \
          __FILE__, \
          __LINE__, \
          c10d_nccl_assert_err_.c_str()); \
      abort(); \
    } \
  } while (0)
|
|
|
|
namespace c10d {
|
|
|
|
// Declaration only; the definition is not in this header. Presumably returns
// the version of the NCCL library in use as text -- NOTE(review): confirm the
// exact format against the definition.
std::string getNcclVersion();
|
|
// Declaration only; the definition is not in this header. Produces the error
// text for `error` that C10D_NCCL_CHECK and C10D_NCCL_ASSERT embed in their
// messages. Judging by the name it also mentions the NCCL version --
// NOTE(review): confirm against the definition.
std::string ncclGetErrorWithVersion(ncclResult_t error);
|
|
|
|
// RAII wrapper for NCCL communicator
|
|
class NCCLComm {
|
|
public:
|
|
explicit NCCLComm(ncclComm_t ncclComm)
|
|
: ncclComm_(ncclComm),
|
|
aborted_(false),
|
|
ncclAsyncErr_(ncclSuccess),
|
|
commFailureReason_(c10::nullopt) {}
|
|
|
|
NCCLComm() : NCCLComm(nullptr) {}
|
|
|
|
~NCCLComm() noexcept {
|
|
// Add lock in this destructor, as aborted_ needs to be read after memory
|
|
// barrier here.
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (ncclComm_ && !aborted_) {
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
// Use ncclCommAbort instead of ncclCommDestroy here since
|
|
// ncclCommDestroy could block forever waiting for work to complete on
|
|
// the communicator.
|
|
C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
|
|
#else
|
|
C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
|
|
#endif
|
|
}
|
|
}
|
|
|
|
static std::shared_ptr<NCCLComm> create(
|
|
int numRanks,
|
|
int rank,
|
|
ncclUniqueId commId) {
|
|
auto comm = std::make_shared<NCCLComm>();
|
|
C10D_NCCL_CHECK(
|
|
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
|
|
comm->ncclId_ = commId;
|
|
comm->rank_ = rank;
|
|
return comm;
|
|
}
|
|
|
|
ncclUniqueId getNcclId() {
|
|
return ncclId_;
|
|
}
|
|
|
|
// Must not be copyable
|
|
NCCLComm(const NCCLComm&) = delete;
|
|
NCCLComm& operator=(const NCCLComm&) = delete;
|
|
|
|
// Do not support move assignment as there is no valid use case
|
|
NCCLComm& operator=(NCCLComm&& other) = delete;
|
|
|
|
// Move constructable
|
|
NCCLComm(NCCLComm&& other) {
|
|
// Using other's lock, as it reads other's states
|
|
// Can not use this.mutex_, as this object is being constructed.
|
|
std::unique_lock<std::mutex> lock(other.mutex_);
|
|
std::swap(ncclComm_, other.ncclComm_);
|
|
std::swap(aborted_, other.aborted_);
|
|
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
|
}
|
|
|
|
ncclComm_t getNcclComm();
|
|
|
|
c10::optional<std::string> getNcclCommFailureReason() const {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
return commFailureReason_;
|
|
}
|
|
|
|
void ncclCommAbort(
|
|
c10::optional<std::string> commFailureReason = c10::nullopt) {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
if (aborted_) {
|
|
// Should not abort twice.
|
|
return;
|
|
}
|
|
|
|
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
|
|
// timeout)
|
|
commFailureReason_ = commFailureReason;
|
|
|
|
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
|
|
aborted_ = true;
|
|
ncclComm_ = nullptr;
|
|
|
|
// Set an appropriate error so that we avoid using the communicator.
|
|
if (ncclAsyncErr_ == ncclSuccess) {
|
|
ncclAsyncErr_ = ncclSystemError;
|
|
}
|
|
#else
|
|
// This is a NOOP, if error checks are disabled.
|
|
return;
|
|
#endif
|
|
}
|
|
|
|
bool isAborted() const {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
return aborted_;
|
|
}
|
|
|
|
ncclResult_t checkForNcclError() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
|
if (ncclAsyncErr_ != ncclSuccess) {
|
|
return ncclAsyncErr_;
|
|
}
|
|
C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
|
|
return ncclAsyncErr_;
|
|
#else
|
|
// Always return success, if error checks are disabled.
|
|
return ncclSuccess;
|
|
#endif
|
|
}
|
|
|
|
protected:
|
|
ncclComm_t ncclComm_;
|
|
// Unique nccl_id for this communicator.
|
|
ncclUniqueId ncclId_;
|
|
bool aborted_;
|
|
ncclResult_t ncclAsyncErr_;
|
|
mutable std::mutex mutex_;
|
|
// Rank that this communicator corresponds to.
|
|
int rank_;
|
|
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
|
|
// better error messaging.
|
|
c10::optional<std::string> commFailureReason_;
|
|
};
|
|
|
|
// Helper that automatically cleans up premul sums.
|
|
struct ncclRedOpRAII {
|
|
ncclRedOpRAII() {}
|
|
ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}
|
|
ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm) :
|
|
op_(op), comm_(comm), premul_sum_(true) {}
|
|
ncclRedOpRAII(const ncclRedOpRAII&) = delete;
|
|
ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;
|
|
ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() {
|
|
std::swap(tmp.op_, this->op_);
|
|
std::swap(tmp.comm_, this->comm_);
|
|
std::swap(tmp.premul_sum_, this->premul_sum_);
|
|
}
|
|
#if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT)
|
|
~ncclRedOpRAII() {
|
|
if (premul_sum_) {
|
|
ncclRedOpDestroy(op_, comm_);
|
|
}
|
|
}
|
|
#endif
|
|
operator ncclRedOp_t() const { return op_; }
|
|
ncclRedOp_t op_;
|
|
ncclComm_t comm_;
|
|
bool premul_sum_ = false;
|
|
};
|
|
|
|
|
|
} // namespace c10d
|
|
|
|
#endif // USE_C10D_NCCL
|