Try to land https://github.com/pytorch/pytorch/pull/136789/files on our end and fix any remaining issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144794
Approved by: https://github.com/kwen2501, https://github.com/eqy, https://github.com/atalman
@@ -252,6 +252,15 @@ class ProcessGroupNCCLInitTest(MultiProcessTestCase):
        x = torch.empty(1, device=self.device)
        c10d.all_reduce(x)

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_scalable_init(self):
        os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = "1"
        self._init_process_group(device_id=self.device)
        x = torch.empty(1, device=self.device)
        c10d.all_reduce(x)
        os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = "0"


class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
    def _create_process_group_nccl(self, store, opts, device_id=None):
@@ -87,6 +87,35 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
  comm->initialized_ = !comm->nonBlocking_;
  return comm;
}
#ifdef NCCL_HAS_INIT_RANK_SCALABLE
std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
    int numRanks,
    int rank,
    std::vector<ncclUniqueId>& commIds,
    ncclConfig_t& config) {
  auto comm = std::make_shared<NCCLComm>();
  comm->nonBlocking_ = config.blocking == 0;
  LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
            << (comm->nonBlocking_ ? "nonblocking" : "blocking")
            << " with scalable init.";
  C10D_NCCL_CHECK_NONBLOCKING(
      ncclCommInitRankScalable(
          &(comm->ncclComm_),
          numRanks,
          rank,
          commIds.size(),
          commIds.data(),
          &config),
      std::nullopt);
  // Only the first ncclUniqueId will be used to create the
  // communicator hash id, which is used to identify the communicator
  // in the log file and in the replay tool.
  comm->ncclId_ = commIds[0];
  comm->rank_ = rank;
  comm->initialized_ = !comm->nonBlocking_;
  return comm;
}
#endif // NCCL_HAS_INIT_RANK_SCALABLE
#endif // NCCL_HAS_CONFIG

ncclComm_t NCCLComm::getNcclComm() {
@@ -26,6 +26,10 @@ constexpr int64_t kCommInitBusyWaitMillis = 2;
#define NCCL_HAS_COMM_SPLIT
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 23, 0)
#define NCCL_HAS_INIT_RANK_SCALABLE
#endif

// ncclGetLastError() is enabled only for NCCL versions 2.13+
// ncclRemoteError only exists in NCCL versions 2.13+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 13, 0)
@@ -212,6 +216,13 @@ class NCCLComm {
      ncclUniqueId commId,
      at::DeviceIndex deviceIndex,
      ncclConfig_t& config);
#ifdef NCCL_HAS_INIT_RANK_SCALABLE
  static std::shared_ptr<NCCLComm> create_scalable(
      int numRanks,
      int rank,
      std::vector<ncclUniqueId>& commIds,
      ncclConfig_t& config);
#endif // NCCL_HAS_INIT_RANK_SCALABLE
#endif // NCCL_HAS_CONFIG

#ifdef NCCL_HAS_COMM_SPLIT
@@ -1,5 +1,6 @@
#ifdef USE_C10D_NCCL

#include <nlohmann/json.hpp>
#include <exception>
#include <map>
#include <mutex>
@@ -2157,6 +2158,32 @@ void ProcessGroupNCCL::checkAndSetRemoteError() {
  }
}

// NCCL recommends evenly distributing ncclUniqueIds across the ranks:
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config
// Let's consider an example where:
// nRanks = 10 (total ranks),
// nIds = 3 (roots),
// rmr = 10 % 3 = 1 (1 larger group),
// rpr = 10 / 3 = 3 (base number of ranks per group),
// rlim = rmr * (rpr + 1) = 4 (ranks covered by the larger groups).
// Output root:
// For ranks [0, 1, 2, 3], root rank is 0 and index is 0.
// For ranks [4, 5, 6], root rank is 4 and index is 1.
// For ranks [7, 8, 9], root rank is 7 and index is 2.
static int getRootIndex(const int rank, const int nRanks, const int nIds) {
  const int rmr = nRanks % nIds;
  const int rpr = nRanks / nIds;
  // The first rmr roots are each assigned one extra rank.
  const int rlim = rmr * (rpr + 1);
  if (rank < rlim) {
    // Roots with `rpr + 1` ranks (indices 0, 1, 2, ..., rmr - 1).
    return rank % (rpr + 1) ? -1 : rank / (rpr + 1);
  } else {
    // Roots with `rpr` ranks (indices rmr, rmr + 1, ..., nIds - 1).
    return (rank - rlim) % rpr ? -1 : ((rank - rlim) / rpr) + rmr;
  }
}

void ProcessGroupNCCL::watchdogHandler() {
  bool done = false;
  lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
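To see the mapping the comment describes, here is a minimal Python sketch of the same arithmetic (a hypothetical `get_root_index` port for illustration, not the C++ function itself); it reproduces the 10-rank, 3-root example above:

def get_root_index(rank: int, n_ranks: int, n_ids: int) -> int:
    rmr = n_ranks % n_ids    # number of larger groups
    rpr = n_ranks // n_ids   # base number of ranks per group
    rlim = rmr * (rpr + 1)   # ranks covered by the larger groups
    if rank < rlim:
        return -1 if rank % (rpr + 1) else rank // (rpr + 1)
    return -1 if (rank - rlim) % rpr else (rank - rlim) // rpr + rmr

# Ranks 0, 4 and 7 are roots with indices 0, 1 and 2; every other rank gets -1.
print([get_root_index(r, 10, 3) for r in range(10)])
# -> [0, -1, -1, -1, 1, -1, -1, 2, -1, -1]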
@@ -2539,6 +2566,63 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
  }
}

// We want to all-gather unique NCCL IDs from all roots using TCPStore.
// This is done by each root first setting its ID in the store, followed by a
// `multiGet` of all IDs by every rank.
void ProcessGroupNCCL::allgatherUniqueNCCLIDs(
    int rootIdx,
    ncclUniqueId* ncclID,
    std::vector<ncclUniqueId>& ncclIDs) {
  std::vector<std::string> storeKeys;
  std::vector<std::vector<uint8_t>> results;
  for (size_t r = 0; r < ncclIDs.size(); r++) {
    storeKeys.emplace_back("UniqueNCCLID:" + std::to_string(r));
  }
  // For a non-root rank, rootIdx is set to -1.
  if (rootIdx >= 0) {
    auto vec = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(ncclID),
        reinterpret_cast<uint8_t*>(ncclID) + NCCL_UNIQUE_ID_BYTES);
    store_->set(storeKeys[rootIdx], vec);
  }
  try {
    results = store_->multiGet(storeKeys);
  } catch (const std::exception& e) {
    nlohmann::json json_vec = storeKeys;
    std::string exceptionMsg = c10::str(
        "[",
        rank_,
        "] is setting up NCCL communicators and "
        "retrieving ncclUniqueId from roots via TCPStore by key '",
        json_vec.dump(),
        "', but got error: ");
    C10_THROW_ERROR(
        DistBackendError,
        exceptionMsg + e.what() +
            ". This may indicate a possible application crash on rank 0 or a network setup issue.");
  } catch (...) {
    nlohmann::json json_vec = storeKeys;
    C10_THROW_ERROR(
        DistBackendError,
        c10::str(
            "Unknown exception while [",
            rank_,
            "] is setting up NCCL communicators and "
            "retrieving ncclUniqueIds from roots via TCPStore by key '",
            json_vec.dump(),
            "'",
            ". This may indicate a possible application crash on rank 0 or a network setup issue."));
  }

  for (size_t r = 0; r < ncclIDs.size(); r++) {
    TORCH_CHECK_WITH(
        DistBackendError,
        results[r].size() == NCCL_UNIQUE_ID_BYTES,
        "Invalid size for ncclUniqueId");
    std::memcpy(&ncclIDs[r], results[r].data(), results[r].size());
  }
}

void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
  std::lock_guard<std::mutex> lock(mutex_);
  if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
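The store protocol above is small enough to sketch end to end: each root publishes its ID under a well-known key, then every rank reads all keys back. A toy Python model with a plain dict standing in for TCPStore (the key format mirrors the C++ code; the helper names are illustrative, not the c10d API):

# Shared key-value store; in the real code this is the TCPStore (store_).
store: dict[str, bytes] = {}

def publish_id(root_idx: int, nccl_id: bytes) -> None:
    # Only roots publish; non-root ranks pass root_idx == -1 and skip this.
    if root_idx >= 0:
        store["UniqueNCCLID:" + str(root_idx)] = nccl_id

def gather_ids(n_roots: int) -> list[bytes]:
    # Every rank reads all root IDs back (the multiGet step).
    return [store["UniqueNCCLID:" + str(i)] for i in range(n_roots)]

Note the dict hides the behavior that matters in practice: the real multiGet blocks until every root has set its key, which is what makes this a rendezvous.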
@@ -2680,13 +2764,57 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
  }
#endif // NCCL_HAS_COMM_SPLIT

  bool useScalableInit = false;
  // (nranks / nroots) == 128 was the default recommended by NCCL, according to
  // https://github.com/pytorch/pytorch/pull/136789#discussion_r1779171615.
  auto ranksPerRoot = getCvarInt(TORCH_NCCL_RANKS_PER_ROOT, 128);
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
  useScalableInit = !singleP2POp && (getSize() > ranksPerRoot);
#endif // NCCL_HAS_INIT_RANK_SCALABLE && NCCL_HAS_CONFIG

  if (useScalableInit) {
    auto numRoots = (getSize() + ranksPerRoot - 1) / ranksPerRoot;
    std::vector<ncclUniqueId> ncclIDs(numRoots);

    if (!ncclComm) {
      auto rootIdx = getRootIndex(rank_, getSize(), numRoots);
      // We only need to generate unique IDs on the roots. For a non-root rank,
      // rootIdx is set to -1.
      if (rootIdx >= 0) {
        C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
      }
      // All ranks participate in the all-gather; only roots contribute an ID.
      auto timeStarted = std::chrono::steady_clock::now();
      allgatherUniqueNCCLIDs(rootIdx, &ncclID, ncclIDs);
      auto timerDeltaMs =
          std::chrono::duration_cast<std::chrono::duration<double>>(
              std::chrono::steady_clock::now() - timeStarted)
              .count() *
          1000;
      LOG(INFO) << logPrefix()
                << "ProcessGroupNCCL all-gather unique IDs through store took "
                << timerDeltaMs << " ms";
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
      ncclComm =
          NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config);
#else
      C10_THROW_ERROR(
          DistBackendError,
          c10::str(
              logPrefix(),
              "create_scalable was called with useScalableInit enabled, but ",
              "NCCL_HAS_INIT_RANK_SCALABLE or NCCL_HAS_CONFIG is not defined; this should not happen."));
#endif // NCCL_HAS_INIT_RANK_SCALABLE
    }
  } else {
    // To simplify conditional nesting, just create the ncclComms[i]
    // entry if it hasn't been yet rather than untangling the
    // conditions that might have resulted in a split above.
    if (!ncclComm) {
      if (getCvarBool(TORCH_NCCL_BCAST_UNIQUEID, true) && !isSendRecvSelf) {
        // For point-to-point communication, the lower rank of the two will
        // get the unique id.
        if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
          C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt);
        }
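As a quick sanity check on the thresholds above, a few sample world sizes run through the same arithmetic (assumed values; 128 is the default for TORCH_NCCL_RANKS_PER_ROOT):

ranks_per_root = 128  # TORCH_NCCL_RANKS_PER_ROOT default
for world_size in (64, 128, 129, 1024):
    use_scalable = world_size > ranks_per_root
    # Ceiling division, as in the C++ numRoots computation.
    n_roots = (world_size + ranks_per_root - 1) // ranks_per_root
    print(world_size, use_scalable, n_roots)
# 64 -> (False, 1); 128 -> (False, 1); 129 -> (True, 2); 1024 -> (True, 8)

So scalable init only engages once the world size strictly exceeds the per-root budget, and the number of roots grows as ceil(world_size / ranks_per_root).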
@@ -2705,12 +2833,13 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
      }

#ifdef NCCL_HAS_CONFIG
      ncclComm = NCCLComm::create(
          numRanks, rank, ncclID, deviceIndex, options_->config);
#else
      ncclComm = NCCLComm::create(numRanks, rank, ncclID, deviceIndex);
#endif // NCCL_HAS_CONFIG
    }
  }

  // Creates the NCCL streams
  bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
@@ -126,6 +126,10 @@ static std::vector<std::string> TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN = {
static std::vector<std::string> TORCH_NCCL_CUDA_EVENT_CACHE = {
    "TORCH_NCCL_CUDA_EVENT_CACHE"};

// Control the number of ranks each root can cover during NCCL comm init.
static std::vector<std::string> TORCH_NCCL_RANKS_PER_ROOT = {
    "TORCH_NCCL_RANKS_PER_ROOT"};

static std::vector<std::string> TORCH_NCCL_NAN_CHECK = {"TORCH_NCCL_NAN_CHECK"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";
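For example, a job can opt into scalable init at smaller world sizes by lowering this knob before the process group is created (illustrative value; the default is 128):

import os

# Must be set before init_process_group; for non-point-to-point communicators,
# a world size strictly above this value triggers the multi-root init path.
os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = "64"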
@@ -804,6 +808,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
      const std::string& devicesKey,
      int p2pRank);

  // Helper that allgathers nccl unique IDs to all ranks through the store
  void allgatherUniqueNCCLIDs(
      int rootIdx,
      ncclUniqueId* ncclID,
      std::vector<ncclUniqueId>& ncclIDs);

  // Helper that looks up the cached NCCL communicators only
  std::shared_ptr<NCCLComm> getNCCLComm(const std::string& deviceKey);