[c10d][NCCL] Implement ncclCommInitRankScalable (merging #136789) (#144794)

Try to land https://github.com/pytorch/pytorch/pull/136789/files on our end and fix any remaining issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144794
Approved by: https://github.com/kwen2501, https://github.com/eqy, https://github.com/atalman
This commit is contained in:
fduwjj
2025-01-30 22:30:02 -08:00
committed by PyTorch MergeBot
parent af2a39849d
commit eb029fba13
5 changed files with 203 additions and 15 deletions

View File

@ -252,6 +252,15 @@ class ProcessGroupNCCLInitTest(MultiProcessTestCase):
x = torch.empty(1, device=self.device)
c10d.all_reduce(x)
@requires_nccl()
@skip_if_lt_x_gpu(1)
def test_scalable_init(self):
    # Force scalable (multi-root) NCCL comm init by capping ranks-per-root
    # at 1, so any group with more than one rank uses multiple roots.
    prev = os.environ.get("TORCH_NCCL_RANKS_PER_ROOT")
    os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = "1"
    try:
        self._init_process_group(device_id=self.device)
        x = torch.empty(1, device=self.device)
        c10d.all_reduce(x)
    finally:
        # Restore the environment even if init/all_reduce raises, so the
        # override does not leak into subsequent tests in this process.
        # (The original reset to "0" unconditionally and only on success.)
        if prev is None:
            os.environ.pop("TORCH_NCCL_RANKS_PER_ROOT", None)
        else:
            os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = prev
class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
def _create_process_group_nccl(self, store, opts, device_id=None):

View File

@ -87,6 +87,35 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
comm->initialized_ = !comm->nonBlocking_;
return comm;
}
#ifdef NCCL_HAS_INIT_RANK_SCALABLE
// Creates a communicator via ncclCommInitRankScalable, which accepts a list
// of ncclUniqueIds (one per root) instead of a single ID, spreading the
// bootstrap load across several roots for large-scale initialization.
// numRanks: total number of ranks in the communicator.
// rank: this process's rank within the communicator.
// commIds: unique IDs gathered from all roots; assumed non-empty (commIds[0]
//          is read unconditionally below).
// config: NCCL communicator config; config.blocking == 0 requests
//         nonblocking initialization.
std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
    int numRanks,
    int rank,
    std::vector<ncclUniqueId>& commIds,
    ncclConfig_t& config) {
  auto comm = std::make_shared<NCCLComm>();
  comm->nonBlocking_ = config.blocking == 0;
  LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
            << (comm->nonBlocking_ ? "nonblocking" : "blocking")
            << " with scalable init.";
  // NOTE(review): C10D_NCCL_CHECK_NONBLOCKING presumably tolerates an
  // in-progress status when nonblocking init is requested — confirm against
  // the macro's definition.
  C10D_NCCL_CHECK_NONBLOCKING(
      ncclCommInitRankScalable(
          &(comm->ncclComm_),
          numRanks,
          rank,
          commIds.size(),
          commIds.data(),
          &config),
      std::nullopt);
  // Only the first ncclUniqueId will be used to create the
  // communicator hash id, which is used to identify the communicator
  // in the log file and in the replay tool.
  comm->ncclId_ = commIds[0];
  comm->rank_ = rank;
  // In nonblocking mode the communicator is not ready yet, so initialized_
  // starts false and is flipped elsewhere once init completes.
  comm->initialized_ = !comm->nonBlocking_;
  return comm;
}
#endif // NCCL_HAS_INIT_RANK_SCALABLE
#endif // NCCL_HAS_CONFIG
ncclComm_t NCCLComm::getNcclComm() {

View File

@ -26,6 +26,10 @@ constexpr int64_t kCommInitBusyWaitMillis = 2;
#define NCCL_HAS_COMM_SPLIT
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 23, 0)
#define NCCL_HAS_INIT_RANK_SCALABLE
#endif
// ncclGetLastError() is enabled only for NCCL versions 2.13+
// ncclRemoteError only exists in NCCL versions 2.13+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 13, 0)
@ -212,6 +216,13 @@ class NCCLComm {
ncclUniqueId commId,
at::DeviceIndex deviceIndex,
ncclConfig_t& config);
#ifdef NCCL_HAS_INIT_RANK_SCALABLE
// Scalable variant of create(): initializes the communicator from a list of
// ncclUniqueIds (one per root) via ncclCommInitRankScalable instead of a
// single broadcast ID.
static std::shared_ptr<NCCLComm> create_scalable(
    int numRanks,
    int rank,
    std::vector<ncclUniqueId>& commIds,
    ncclConfig_t& config);
#endif // NCCL_HAS_INIT_RANK_SCALABLE
#endif // NCCL_HAS_CONFIG
#ifdef NCCL_HAS_COMM_SPLIT

View File

@ -1,5 +1,6 @@
#ifdef USE_C10D_NCCL
#include <nlohmann/json.hpp>
#include <exception>
#include <map>
#include <mutex>
@ -2157,6 +2158,32 @@ void ProcessGroupNCCL::checkAndSetRemoteError() {
}
}
// NCCL recommends evenly distributing ncclUniqueIds across the ranks:
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config
// Returns this rank's root index if it is a root, or -1 otherwise.
// Ranks are partitioned into nIds contiguous groups; the first
// (nRanks % nIds) groups hold one extra rank, and the first rank of each
// group is that group's root.
// Example: nRanks = 10, nIds = 3 gives one group of 4 and two groups of 3:
//   ranks [0, 1, 2, 3] -> root rank 0, index 0
//   ranks [4, 5, 6]    -> root rank 4, index 1
//   ranks [7, 8, 9]    -> root rank 7, index 2
static int getRootIndex(const int rank, const int nRanks, const int nIds) {
  // Number of groups that carry one extra rank, and the base group size.
  const int extraGroups = nRanks % nIds;
  const int baseSize = nRanks / nIds;
  // First rank that falls into a base-sized group.
  const int boundary = extraGroups * (baseSize + 1);
  if (rank >= boundary) {
    // Base-sized groups, root indices (extraGroups, ..., nIds - 1).
    const int offset = rank - boundary;
    return offset % baseSize == 0 ? offset / baseSize + extraGroups : -1;
  }
  // Oversized groups of `baseSize + 1`, root indices (0, ..., extraGroups - 1).
  return rank % (baseSize + 1) == 0 ? rank / (baseSize + 1) : -1;
}
void ProcessGroupNCCL::watchdogHandler() {
bool done = false;
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
@ -2539,6 +2566,63 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
}
}
// We want to all-gather unique NCCL IDs from all roots using TCPStore.
// This is first done by setting the ID by each root and then `multiGet` by all
// ranks.
// rootIdx: this rank's root index, or -1 if this rank is not a root.
// ncclID: the locally generated unique ID (only read when rootIdx >= 0).
// ncclIDs: output vector, pre-sized to the number of roots; filled with
//          every root's unique ID on return.
// Throws DistBackendError if the store lookup fails or a retrieved ID has an
// unexpected size.
void ProcessGroupNCCL::allgatherUniqueNCCLIDs(
    int rootIdx,
    ncclUniqueId* ncclID,
    std::vector<ncclUniqueId>& ncclIDs) {
  std::vector<std::string> storeKeys;
  std::vector<std::vector<uint8_t>> results;
  // One well-known store key per root; every rank computes the same key list.
  for (size_t r = 0; r < ncclIDs.size(); r++) {
    storeKeys.emplace_back("UniqueNCCLID:" + std::to_string(r));
  }
  // For non-root rank, rootIdx is set to -1.
  if (rootIdx >= 0) {
    // Publish this root's ID as raw bytes under its key.
    auto vec = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(ncclID),
        reinterpret_cast<uint8_t*>(ncclID) + NCCL_UNIQUE_ID_BYTES);
    store_->set(storeKeys[rootIdx], vec);
  }
  try {
    // NOTE(review): presumably multiGet blocks until all keys are set by the
    // roots — confirm against the store implementation.
    results = store_->multiGet(storeKeys);
  } catch (const std::exception& e) {
    // Include the full key list (as JSON) in the error for debuggability.
    nlohmann::json json_vec = storeKeys;
    std::string exceptionMsg = c10::str(
        "[",
        rank_,
        "] is setting up NCCL communicators and "
        "retrieving ncclUniqueId from roots via TCPStore by key '",
        json_vec.dump(),
        "', but got error: ");
    C10_THROW_ERROR(
        DistBackendError,
        exceptionMsg + e.what() +
            ". This may indicate a possible application crash on rank 0 or a network set up issue.");
  } catch (...) {
    nlohmann::json json_vec = storeKeys;
    C10_THROW_ERROR(
        DistBackendError,
        c10::str(
            "Unknown exception while [",
            rank_,
            "] is setting up NCCL communicators and "
            "retrieving ncclUniqueIds from roots via TCPStore by key '",
            json_vec.dump(),
            "'",
            ". This may indicate a possible application crash on rank 0 or a network set up issue."));
  }
  // Validate and copy each root's ID bytes into the output vector.
  for (size_t r = 0; r < ncclIDs.size(); r++) {
    TORCH_CHECK_WITH(
        DistBackendError,
        results[r].size() == NCCL_UNIQUE_ID_BYTES,
        "Invalid size for ncclUniqueId");
    std::memcpy(&ncclIDs[r], results[r].data(), results[r].size());
  }
}
void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
std::lock_guard<std::mutex> lock(mutex_);
if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
@ -2680,36 +2764,81 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
}
#endif // NCCL_HAS_COMM_SPLIT
// To simplify conditional nesting, just create the ncclComms[i]
// entry if it hasn't been yet rather than untangling the
// conditions that might have resulted in a split above.
if (!ncclComm) {
if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) {
// For point-to-point communication, lower rank of the two will get unique
// id.
if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
bool useScalableInit = false;
// (nranks / nroots) == 128 was the default NCCL recommended
// according to
// https://github.com/pytorch/pytorch/pull/136789#discussion_r1779171615.
auto ranksPerRoot = getCvarInt(TORCH_NCCL_RANKS_PER_ROOT, 128);
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
useScalableInit = !singleP2POp && (getSize() > ranksPerRoot);
#endif // NCCL_HAS_INIT_RANK_SCALABLE && NCCL_HAS_CONFIG
if (useScalableInit) {
auto numRoots = (getSize() + ranksPerRoot - 1) / ranksPerRoot;
std::vector<ncclUniqueId> ncclIDs(numRoots);
if (!ncclComm) {
auto rootIdx = getRootIndex(rank_, getSize(), numRoots);
// We only need to get unique IDs for roots. For non-root rank, index is
// set to -1.
if (rootIdx >= 0) {
C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
}
// Broadcast so that each process can have a unique NCCL ID
// We only need to all-gather the ncclID if the rank is root.
auto timeStarted = std::chrono::steady_clock::now();
broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank);
allgatherUniqueNCCLIDs(rootIdx, &ncclID, ncclIDs);
auto timerDeltaMs =
std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::steady_clock::now() - timeStarted)
.count() *
1000;
LOG(INFO) << logPrefix()
<< "ProcessGroupNCCL broadcast unique ID through store took "
<< "ProcessGroupNCCL all-gather unique IDs through store took "
<< timerDeltaMs << " ms";
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
ncclComm =
NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config);
#else
C10_THROW_ERROR(
DistBackendError,
c10::str(
logPrefix(),
"create_scalable is called when useScalableInit is enabled but ",
"neither NCCL_HAS_INIT_RANK_SCALABLE nor NCCL_HAS_CONFIG is not defined, this should not happen "));
#endif // NCCL_HAS_INIT_RANK_SCALABLE
}
} else {
// To simplify conditional nesting, just create the ncclComms[i]
// entry if it hasn't been yet rather than untangling the
// conditions that might have resulted in a split above.
if (!ncclComm) {
if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) {
// For point-to-point communication, lower rank of the two will get
// unique id.
if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
}
// Broadcast so that each process can have a unique NCCL ID
auto timeStarted = std::chrono::steady_clock::now();
broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank);
auto timerDeltaMs =
std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::steady_clock::now() - timeStarted)
.count() *
1000;
LOG(INFO) << logPrefix()
<< "ProcessGroupNCCL broadcast unique ID through store took "
<< timerDeltaMs << " ms";
}
#ifdef NCCL_HAS_CONFIG
ncclComm =
NCCLComm::create(numRanks, rank, ncclID, deviceIndex, options_->config);
ncclComm = NCCLComm::create(
numRanks, rank, ncclID, deviceIndex, options_->config);
#else
ncclComm = NCCLComm::create(numRanks, rank, ncclID, deviceIndex);
ncclComm = NCCLComm::create(numRanks, rank, ncclID, deviceIndex);
#endif // NCCL_HAS_CONFIG
}
}
// Creates the NCCL streams

View File

@ -126,6 +126,10 @@ static std::vector<std::string> TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = {
static std::vector<std::string> TORCH_NCCL_CUDA_EVENT_CACHE = {
"TORCH_NCCL_CUDA_EVENT_CACHE"};
// Control the number of ranks each root can cover during NCCL comm init.
// When scalable init is available and the group size exceeds this value,
// comm init uses multiple roots: (size + ranksPerRoot - 1) / ranksPerRoot.
// Defaults to 128 when unset.
static std::vector<std::string> TORCH_NCCL_RANKS_PER_ROOT = {
    "TORCH_NCCL_RANKS_PER_ROOT"};
static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};
constexpr const char* NCCL_BACKEND_NAME = "nccl";
@ -804,6 +808,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
const std::string& devicesKey,
int p2pRank);
// Helper that allgathers nccl unique IDs to all ranks through the store.
// rootIdx is this rank's root index (-1 for non-root ranks); roots publish
// their ncclID to the store and every rank reads back all roots' IDs into
// ncclIDs.
void allgatherUniqueNCCLIDs(
    int rootIdx,
    ncclUniqueId* ncclID,
    std::vector<ncclUniqueId>& ncclIDs);
// Helper that looks up the cached NCCL communicators only
std::shared_ptr<NCCLComm> getNCCLComm(const std::string& deviceKey);