Opportunistically use ncclCommSplit when creating new NCCL groups (#112889)

Currently, `ncclCommInitRankConfig` is always used when creating new
communicator groups.  This is wasteful: it creates non-shared pairs of
endpoint queues and spends time re-establishing communication.

This change is transparent and opportunistic: when `dist.new_group` is
called, it uses the existing, healthy world process group to select the
ranks to include in the new group.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112889
Approved by: https://github.com/kwen2501
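
For context, a rough end-to-end sketch of the behavior this enables (not part
of this commit; it assumes a `torchrun` launch with one GPU per rank, an NCCL
build new enough to define `NCCL_HAS_COMM_SPLIT`, i.e. NCCL >= 2.18, and uses
the private `_get_default_group` helper purely for inspection):

```python
# Sketch (not part of this commit): launch with
#   torchrun --nproc-per-node=<num_gpus> split_demo.py
# Assumes NVIDIA GPUs with NCCL available; the split path needs NCCL >= 2.18.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# The default group creates its NCCL communicators lazily, on first use.
t = torch.ones(1, device="cuda")
dist.all_reduce(t)

# A new group over all ranks can now be split from the default group's
# communicators instead of re-running ncclCommInitRankConfig.
ng = dist.new_group()
dist.all_reduce(t, group=ng)

# The split counter lives on the NCCL backend of the *original* group
# (comm_split_count is the binding added by this commit); expect 1 here
# when the split kicked in, 0 on older NCCL builds.
backend = dist.distributed_c10d._get_default_group()._get_backend(torch.device("cuda"))
print(f"rank {rank}: comm_split_count = {backend.comm_split_count()}")

dist.destroy_process_group()
```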
Chip Turner
2023-11-21 21:03:48 +00:00
committed by PyTorch MergeBot
parent 3b108a150a
commit 64a5372e6c
7 changed files with 214 additions and 23 deletions

View File

@@ -31,12 +31,20 @@ class NCCLTestBase {
     pg_ = std::move(other.pg_);
   }
-  ::c10d::ProcessGroupNCCL& getProcessGroup() {
-    return *pg_;
+  std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
+    return pg_;
   }
-  void initialize(int rank, int size) {
-    auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
+  ::c10::intrusive_ptr<::c10d::Store>& getProcessGroupStore() {
+    return store_;
+  }
+  void initialize(
+      int rank,
+      int size,
+      c10::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
+          c10::nullopt) {
+    store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
     c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts =
         c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
@@ -45,14 +53,22 @@ class NCCLTestBase {
         c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(),
         "1",
         /* overwrite */ 1);
+#ifdef NCCL_HAS_COMM_SPLIT
+    if (split_from) {
+      opts->split_from = *split_from;
+      opts->split_color = ++color_;
+    }
+#endif
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
+        new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts)));
   }
 protected:
   std::string path_;
-  std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
   std::chrono::milliseconds pgTimeout_;
+  ::c10::intrusive_ptr<::c10d::Store> store_;
+  int color_{1};
 };
 class NCCLTest : public NCCLTestBase {
@@ -718,9 +734,9 @@ void testSequenceNumInit(
   auto runTest = [&](int i) {
     NCCLTest test(path, worldSize);
     test.initialize(i, worldSize);
-    test.getProcessGroup().setSequenceNumberForGroup();
+    test.getProcessGroup()->setSequenceNumberForGroup();
     std::lock_guard<std::mutex> lock(m);
-    auto seqNum = test.getProcessGroup().getSequenceNumberForGroup();
+    auto seqNum = test.getProcessGroup()->getSequenceNumberForGroup();
     nums.insert(seqNum);
   };
   std::vector<std::thread> threads;
@@ -877,11 +893,55 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) {
   auto test = NCCLTestBase(file.path);
   test.initialize(rank_, size_);
   EXPECT_EQ(
-      test.getProcessGroup().getBackendName(),
+      test.getProcessGroup()->getBackendName(),
       std::string(c10d::NCCL_BACKEND_NAME));
 }
+TEST_F(ProcessGroupNCCLTest, testSplittingCommunicator) {
+  if (skipTest()) {
+    return;
+  }
+  TemporaryFile file;
+  auto test1 = BroadcastNCCLTest(file.path, size_);
+  test1.initialize(rank_, size_);
+  auto test2 = BroadcastNCCLTest(file.path, size_);
+  test2.initialize(rank_, size_, test1.getProcessGroup());
+  // Steal the broadcast test and issue it for both of our groups.
+  // This ensures consistent full collective communication.  TODO:
+  // maybe refactor the guts rather than copy-pasta, but it may not be
+  // worth it.
+  for (auto test : {&test1, &test2}) {
+    const int numDevices = test->numDevices();
+    // try every permutation of root rank and root tensor
+    for (const auto rootRank : c10::irange(size_)) {
+      for (const auto rootTensor : c10::irange(numDevices)) {
+        auto work = test->run(rootRank, rootTensor);
+        test->wait(work);
+        // Check results
+        const auto expected = (rootRank * numDevices + rootTensor);
+        const auto tensors = test->getTensors();
+        for (const auto& tensor : tensors) {
+          const auto* const data = tensor.data_ptr<float>();
+          for (const auto k : c10::irange(tensor.numel())) {
+            EXPECT_EQ(data[k], expected)
+                << "Broadcast outputs do not match expected outputs";
+          }
+        }
+      }
+    }
+  }
+  // Now that we've run full operations on both the original and split process
+  // group, ensure we saw exactly as many splits as we expected: none in the
+  // second (split-off) process group, and one per device in the original.
+  EXPECT_EQ(test2.getProcessGroup()->getCommSplitCounter(), 0);
+  EXPECT_EQ(test1.getProcessGroup()->getCommSplitCounter(), test1.numDevices());
+}
 #ifdef IS_NCCL_EXP
 TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) {
   if (skipTest()) {

View File

@@ -1272,6 +1272,27 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
         # Verification
         self.assertEqual(torch.arange(self.world_size), output_t)
+    @requires_nccl()
+    def test_comm_split_optimization(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        # Test lazy splitting behavior across each per-device backend.
+        for device in self.rank_to_GPU[self.rank]:
+            backend = pg._get_backend(torch.device(device))
+            # split doesn't happen unless the original process group has lazily
+            # created communicators, so first verify we haven't split even when
+            # making the new group and running an operation on the original pg.
+            ng = c10d.new_group()
+            tensor = torch.tensor([self.rank]).cuda(device)
+            pg.broadcast(tensor, 0)
+            self.assertEqual(backend.comm_split_count(), 0)
+            # The new group will force a split of the original on first use.
+            ng.broadcast(tensor, 0)
+            self.assertEqual(backend.comm_split_count(), 1)
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
 ):
@@ -3676,7 +3697,6 @@ class NCCLTraceTest(MultiProcessTestCase):
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized

View File

@@ -17,6 +17,11 @@
 #define NCCL_HAS_COMM_NONBLOCKING
 #endif
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR >= 18)
+#define NCCL_HAS_COMM_SPLIT
+#endif
 // ncclGetLastError() is enabled only for NCCL versions 2.13+
 // ncclRemoteError only exists in NCCL versions 2.13+
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
@@ -246,6 +251,22 @@ class NCCLComm {
   }
 #endif
+#ifdef NCCL_HAS_COMM_SPLIT
+  static std::shared_ptr<NCCLComm> split(
+      NCCLComm* source,
+      int color_id,
+      int rank,
+      ncclConfig_t& config) {
+    auto comm = std::make_shared<NCCLComm>();
+    C10D_NCCL_CHECK(
+        ncclCommSplit(
+            source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
+        c10::nullopt);
+    ++source->ncclCommSplitCounter_;
+    return comm;
+  }
+#endif
   ncclUniqueId getNcclId() {
     return ncclId_;
   }
@@ -325,6 +346,10 @@ class NCCLComm {
     return aborted_;
   }
+  uint64_t getCommSplitCounter() const {
+    return ncclCommSplitCounter_;
+  }
   ncclResult_t checkForNcclError() {
     std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
@@ -401,6 +426,7 @@ class NCCLComm {
   // Unique nccl_id for this communicator.
   ncclUniqueId ncclId_;
   bool aborted_;
+  uint64_t ncclCommSplitCounter_{0};
   ncclResult_t ncclAsyncErr_;
   mutable std::mutex mutex_;
   // Rank that this communicator corresponds to.

View File

@@ -1898,12 +1898,41 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
     gpuGuard.set_index(deviceIndex);
-#ifdef NCCL_HAS_COMM_NONBLOCKING
-    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
-#else
-    ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
+#ifdef NCCL_HAS_COMM_SPLIT
+    if (options_->split_from) {
+      TORCH_CHECK(
+          options_->split_color != 0,
+          "Must specify a non-zero color when splitting");
+      // Find a valid, healthy communicator to split from if possible.
+      std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
+      auto& other_comms = options_->split_from->devNCCLCommMap_;
+      auto dit = other_comms.find(devicesKey);
+      if (dit != other_comms.end() && !dit->second.empty()) {
+        TORCH_INTERNAL_ASSERT(
+            dit->second.size() == ncclComms.size(),
+            "split_from->devNCCLCommMap_ should be empty or the same size as ncclComms!");
+        if (dit->second[i] && !dit->second[i]->isAborted()) {
+          ncclComms[i] = NCCLComm::split(
+              dit->second[i].get(),
+              options_->split_color,
+              rank,
+              options_->config);
+        }
+      }
+    }
+#endif
+    // To simplify conditional nesting, just create the ncclComms[i]
+    // entry if it hasn't been created yet, rather than untangling the
+    // conditions that might have resulted in a split above.
+    if (!ncclComms[i]) {
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
+#else
+      ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
 #endif
+    }
     // Creates the NCCL streams
     streamVal.push_back(
         at::cuda::getStreamFromPool(options_->is_high_priority_stream));
@@ -1948,9 +1977,6 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
       std::make_tuple(devicesKey),
       std::make_tuple(devices.size()));
-  // Hold the lock before modifying the cache.
-  std::lock_guard<std::mutex> lock(mutex_);
   // Record the communicators based on ncclUniqueId.
   ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);
@@ -1994,9 +2020,20 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   it = devNCCLCommMap_.find(devicesKey);
   TORCH_INTERNAL_ASSERT(
       it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
   return it->second;
 }
+uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
+  uint64_t ret = 0;
+  for (const auto& i : ncclIdToCommMap_) {
+    for (const auto& j : i.second) {
+      ret += j->getCommSplitCounter();
+    }
+  }
+  return ret;
+}
 namespace {
 // Check validity of tensor

View File

@@ -342,6 +342,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif
+    // Optional "parent" backend and color to create communicators from
+    // via `ncclCommSplit`
+#ifdef NCCL_HAS_COMM_SPLIT
+    std::shared_ptr<ProcessGroupNCCL> split_from;
+    int64_t split_color{0};
+#endif
   };
   // If you wish to create multiple process groups, each with a potentially
@@ -510,6 +517,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // may indicate that there is some sort of collective desynchronization.
   uint64_t getSequenceNumberForGroup() override;
+  // Return the total number of splits the communicators held by this process
+  // group have performed.
+  uint64_t getCommSplitCounter() const;
   void registerOnCompletionHook(
       std::function<void(std::shared_ptr<WorkInfo>)>&& hook) override;
   void waitForPendingWorks() override;
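
The two new `Options` fields are also exposed to Python via the pybind
bindings in the next file below. As a minimal, hand-wired sketch of the same
mechanism `_new_process_group_helper` uses later in this commit (assumptions:
`torchrun` launch on a single node, NVIDIA GPUs, NCCL >= 2.18; the FileStore
path and color value are arbitrary choices for illustration):

```python
# Sketch only: drive split_from/split_color by hand. dist.new_group()
# normally does this for you; this is just to show what the fields mean.
import os
import torch
import torch.distributed as dist
from torch.distributed import FileStore, ProcessGroupNCCL

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Make the default group create its communicators so there is something
# to split from.
t = torch.ones(1, device="cuda")
dist.all_reduce(t)
parent = dist.distributed_c10d._get_default_group()._get_backend(torch.device("cuda"))

opts = ProcessGroupNCCL.Options()
opts.is_high_priority_stream = False
opts.split_from = parent   # field added in this commit
opts.split_color = 1       # must be non-zero; same color on every rank

# A per-run store path so stale files from earlier runs don't interfere.
store = FileStore(f"/tmp/split_demo_{os.environ.get('MASTER_PORT', '0')}", world)
pg = ProcessGroupNCCL(store, rank, world, opts)
pg.allreduce([t]).wait()   # first collective triggers the split
print(f"rank {rank}: parent comm_split_count = {parent.comm_split_count()}")
```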

View File

@@ -2290,6 +2290,9 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           py::call_guard<py::gil_scoped_release>())
       .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
       .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
+      .def(
+          "comm_split_count",
+          &::c10d::ProcessGroupNCCL::getCommSplitCounter)
       .def_property_readonly(
           "options", &::c10d::ProcessGroupNCCL::getOptions)
       .def_property_readonly(
@@ -2354,15 +2357,18 @@ Example::
       )")
       .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
 #ifdef NCCL_HAS_COMM_CTA_CGA
+      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
+#endif
       .def_readwrite(
           "is_high_priority_stream",
           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
-      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
-#else
+#ifdef NCCL_HAS_COMM_SPLIT
       .def_readwrite(
-          "is_high_priority_stream",
-          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
+          "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
+      .def_readwrite(
+          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
+#endif
+      ;
 #endif

View File

@@ -8,6 +8,7 @@ import io
 import logging
 import os
 import pickle
+import sys
 import time
 import warnings
 from collections import namedtuple
@@ -1314,7 +1315,29 @@ def _new_process_group_helper(
             pg_options.is_high_priority_stream = False
             pg_options._timeout = timeout
-            backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options)
+            # If our new group includes all ranks, we can reduce
+            # overhead by splitting the communicator (`ncclCommSplit`).
+            # TODO: support this in the general case by calling
+            # `ncclCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks
+            # not in the communicator.
+            split_from = None
+            if (
+                is_initialized()
+                and _world.default_pg._get_backend_name() == Backend.NCCL
+                and len(global_ranks_in_group) == _world.default_pg.size()
+            ):
+                # If possible, find a backend to split from by peeling
+                # process group wrappers from the world's default pg.
+                split_from = _world.default_pg._get_backend(_get_pg_default_device())
+                while isinstance(split_from, _ProcessGroupWrapper):
+                    split_from = split_from.wrapped_pg
+                if split_from:
+                    pg_options.split_from = split_from
+                    pg_options.split_color = _process_group_color(global_ranks_in_group)
+            backend_class = ProcessGroupNCCL(
+                backend_prefix_store, group_rank, group_size, pg_options)
             backend_type = ProcessGroup.BackendType.NCCL
         elif backend_str == Backend.UCC and is_ucc_available():
             # TODO: once UCC plugin is fully deprecated, remove
@@ -3514,11 +3537,19 @@ def _create_process_group_wrapper(
         wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
     return wrapped_pg
+# helper function for deterministically hashing a list of ranks
+def _hash_ranks(ranks: List[int]):
+    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
+# Takes a list of ranks and computes an integer color
+def _process_group_color(ranks: List[int]) -> int:
+    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
+    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
 def _process_group_name(ranks, use_hashed_name):
     global _world
     if use_hashed_name:
-        pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
+        pg_name = _hash_ranks(ranks)
         while pg_name in _world.pg_names.values():
             pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
     else: