[nccl] Wrap nccl code update with version check (#130419)

Fixes the issue where PyTorch cannot be built with NCCL < 2.13 after https://github.com/pytorch/pytorch/issues/128756

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130419
Approved by: https://github.com/eqy, https://github.com/malfet
This commit is contained in:
Yichen Yan
2024-08-02 01:22:05 +00:00
committed by PyTorch MergeBot
parent 50ed6ce277
commit ef426d5183
2 changed files with 9 additions and 2 deletions

View File

@ -17,9 +17,12 @@
#include <unordered_map>
#if !defined(USE_ROCM) && \
((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 13)))
#define NCCL_HAS_REMOTE_ERROR 1
#if (NCCL_MAJOR > 2) || (NCCL_MINOR >= 14)
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif
#endif
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
@ -47,8 +50,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
return ncclResult_t::ncclInvalidArgument;
case torch::cuda::nccl::ncclResult::InvalidUsage:
return ncclResult_t::ncclInvalidUsage;
#ifdef NCCL_HAS_REMOTE_ERROR
case torch::cuda::nccl::ncclResult::RemoteError:
return ncclResult_t::ncclRemoteError;
#endif
#ifdef NCCL_HAS_COMM_NONBLOCKING
case torch::cuda::nccl::ncclResult::InProgress:
return ncclResult_t::ncclInProgress;
@ -74,8 +79,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
return torch::cuda::nccl::ncclResult::InvalidArgument;
case ncclInvalidUsage:
return torch::cuda::nccl::ncclResult::InvalidUsage;
#ifdef NCCL_HAS_REMOTE_ERROR
case ncclRemoteError:
return torch::cuda::nccl::ncclResult::RemoteError;
#endif
#ifdef NCCL_HAS_COMM_NONBLOCKING
case ncclInProgress:
return torch::cuda::nccl::ncclResult::InProgress;

View File

@ -327,7 +327,6 @@ class NCCLComm {
comm->initialized_ = isInitialized;
return comm;
}
#endif
static std::shared_ptr<NCCLComm> split(
NCCLComm* source,
@ -335,6 +334,7 @@ class NCCLComm {
int rank,
ncclConfig_t& config,
std::vector<uint64_t>& ranks_ull);
#endif
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump() {