mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Opportunistically use ncclCommSplit
when creating new NCCL groups (#112889)"
This reverts commit 64a5372e6ce9b6ca0ee5c7482b27e24561725b28.
Reverted https://github.com/pytorch/pytorch/pull/112889 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing ROCm distributed jobs in trunk 4d07428ede
([comment](https://github.com/pytorch/pytorch/pull/112889#issuecomment-1823214376))
This commit is contained in:
@ -31,20 +31,12 @@ class NCCLTestBase {
|
||||
pg_ = std::move(other.pg_);
|
||||
}
|
||||
|
||||
std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
|
||||
return pg_;
|
||||
::c10d::ProcessGroupNCCL& getProcessGroup() {
|
||||
return *pg_;
|
||||
}
|
||||
|
||||
::c10::intrusive_ptr<::c10d::Store>& getProcessGroupStore() {
|
||||
return store_;
|
||||
}
|
||||
|
||||
void initialize(
|
||||
int rank,
|
||||
int size,
|
||||
c10::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
|
||||
c10::nullopt) {
|
||||
store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
|
||||
void initialize(int rank, int size) {
|
||||
auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts =
|
||||
c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
|
||||
@ -53,22 +45,14 @@ class NCCLTestBase {
|
||||
c10d::TORCH_ENABLE_NCCL_HEALTH_CHECK[0].c_str(),
|
||||
"1",
|
||||
/* overwrite */ 1);
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
if (split_from) {
|
||||
opts->split_from = *split_from;
|
||||
opts->split_color = ++color_;
|
||||
}
|
||||
#endif
|
||||
pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
|
||||
new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts)));
|
||||
new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string path_;
|
||||
std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
|
||||
std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
|
||||
std::chrono::milliseconds pgTimeout_;
|
||||
::c10::intrusive_ptr<::c10d::Store> store_;
|
||||
int color_{1};
|
||||
};
|
||||
|
||||
class NCCLTest : public NCCLTestBase {
|
||||
@ -734,9 +718,9 @@ void testSequenceNumInit(
|
||||
auto runTest = [&](int i) {
|
||||
NCCLTest test(path, worldSize);
|
||||
test.initialize(i, worldSize);
|
||||
test.getProcessGroup()->setSequenceNumberForGroup();
|
||||
test.getProcessGroup().setSequenceNumberForGroup();
|
||||
std::lock_guard<std::mutex> lock(m);
|
||||
auto seqNum = test.getProcessGroup()->getSequenceNumberForGroup();
|
||||
auto seqNum = test.getProcessGroup().getSequenceNumberForGroup();
|
||||
nums.insert(seqNum);
|
||||
};
|
||||
std::vector<std::thread> threads;
|
||||
@ -893,55 +877,11 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) {
|
||||
auto test = NCCLTestBase(file.path);
|
||||
test.initialize(rank_, size_);
|
||||
EXPECT_EQ(
|
||||
test.getProcessGroup()->getBackendName(),
|
||||
test.getProcessGroup().getBackendName(),
|
||||
std::string(c10d::NCCL_BACKEND_NAME));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ProcessGroupNCCLTest, testSplittingCommunicator) {
|
||||
if (skipTest()) {
|
||||
return;
|
||||
}
|
||||
TemporaryFile file;
|
||||
auto test1 = BroadcastNCCLTest(file.path, size_);
|
||||
test1.initialize(rank_, size_);
|
||||
|
||||
auto test2 = BroadcastNCCLTest(file.path, size_);
|
||||
test2.initialize(rank_, size_, test1.getProcessGroup());
|
||||
|
||||
// Steal the broadcast test and issue it for both of our groups.
|
||||
// This ensures consistent full collective communication. TODO:
|
||||
// maybe refactor the guts rather than copy-pasta, but it may not be
|
||||
// worth it.
|
||||
for (auto test : {&test1, &test2}) {
|
||||
const int numDevices = test->numDevices();
|
||||
// try every permutation of root rank and root tensor
|
||||
for (const auto rootRank : c10::irange(size_)) {
|
||||
for (const auto rootTensor : c10::irange(numDevices)) {
|
||||
auto work = test->run(rootRank, rootTensor);
|
||||
test->wait(work);
|
||||
|
||||
// Check results
|
||||
const auto expected = (rootRank * numDevices + rootTensor);
|
||||
const auto tensors = test->getTensors();
|
||||
for (const auto& tensor : tensors) {
|
||||
const auto* const data = tensor.data_ptr<float>();
|
||||
for (const auto k : c10::irange(tensor.numel())) {
|
||||
EXPECT_EQ(data[k], expected)
|
||||
<< "Broadcast outputs do not match expected outputs";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've run full operations on both the original and split process
|
||||
// group, ensure we saw exactly as many splits as we expected: 0 in the
|
||||
// original process group, and one per device in the second.
|
||||
EXPECT_EQ(test2.getProcessGroup()->getCommSplitCounter(), 0);
|
||||
EXPECT_EQ(test1.getProcessGroup()->getCommSplitCounter(), test1.numDevices());
|
||||
}
|
||||
|
||||
#ifdef IS_NCCL_EXP
|
||||
TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) {
|
||||
if (skipTest()) {
|
||||
|
@ -1272,27 +1272,6 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
||||
# Verification
|
||||
self.assertEqual(torch.arange(self.world_size), output_t)
|
||||
|
||||
@requires_nccl()
|
||||
def test_comm_split_optimization(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
|
||||
# Test lazy splitting behavior across each per-device backend.
|
||||
for device in self.rank_to_GPU[self.rank]:
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
|
||||
# split doesn't happen unless the original process group has lazily
|
||||
# created communicators, so first verify we haven't split even when
|
||||
# making the new group and running an operation on the original pg.
|
||||
ng = c10d.new_group()
|
||||
tensor = torch.tensor([self.rank]).cuda(device)
|
||||
pg.broadcast(tensor, 0)
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
|
||||
# The new group will force a split of the original on first use.
|
||||
ng.broadcast(tensor, 0)
|
||||
self.assertEqual(backend.comm_split_count(), 1)
|
||||
|
||||
class DistributedDataParallelTest(
|
||||
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
|
||||
):
|
||||
@ -3697,6 +3676,7 @@ class NCCLTraceTest(MultiProcessTestCase):
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
|
@ -17,11 +17,6 @@
|
||||
#define NCCL_HAS_COMM_NONBLOCKING
|
||||
#endif
|
||||
|
||||
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
|
||||
(NCCL_MINOR >= 18)
|
||||
#define NCCL_HAS_COMM_SPLIT
|
||||
#endif
|
||||
|
||||
// ncclGetLastError() is enabled only for NCCL versions 2.13+
|
||||
// ncclRemoteError only exists in NCCL versions 2.13+
|
||||
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
|
||||
@ -251,22 +246,6 @@ class NCCLComm {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
static std::shared_ptr<NCCLComm> split(
|
||||
NCCLComm* source,
|
||||
int color_id,
|
||||
int rank,
|
||||
ncclConfig_t& config) {
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommSplit(
|
||||
source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
|
||||
c10::nullopt);
|
||||
++source->ncclCommSplitCounter_;
|
||||
return comm;
|
||||
}
|
||||
#endif
|
||||
|
||||
ncclUniqueId getNcclId() {
|
||||
return ncclId_;
|
||||
}
|
||||
@ -346,10 +325,6 @@ class NCCLComm {
|
||||
return aborted_;
|
||||
}
|
||||
|
||||
uint64_t getCommSplitCounter() const {
|
||||
return ncclCommSplitCounter_;
|
||||
}
|
||||
|
||||
ncclResult_t checkForNcclError() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
@ -426,7 +401,6 @@ class NCCLComm {
|
||||
// Unique nccl_id for this communicator.
|
||||
ncclUniqueId ncclId_;
|
||||
bool aborted_;
|
||||
uint64_t ncclCommSplitCounter_{0};
|
||||
ncclResult_t ncclAsyncErr_;
|
||||
mutable std::mutex mutex_;
|
||||
// Rank that this communicator corresponds to.
|
||||
|
@ -1875,40 +1875,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
int deviceIndex = devices[i].index();
|
||||
|
||||
gpuGuard.set_index(deviceIndex);
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
if (options_->split_from) {
|
||||
TORCH_CHECK(
|
||||
options_->split_color != 0,
|
||||
"Must specify a non-zero color when splitting");
|
||||
// Find a valid, healthy communicator to split from if possible.
|
||||
std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
|
||||
auto& other_comms = options_->split_from->devNCCLCommMap_;
|
||||
auto dit = other_comms.find(devicesKey);
|
||||
if (dit != other_comms.end() && !dit->second.empty()) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dit->second.size() == ncclComms.size(),
|
||||
"split_from->devNCCLCommMap_ should be empty or the same size as ncclComms!");
|
||||
if (dit->second[i] && !dit->second[i]->isAborted()) {
|
||||
ncclComms[i] = NCCLComm::split(
|
||||
dit->second[i].get(),
|
||||
options_->split_color,
|
||||
rank,
|
||||
options_->config);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// To simplify conditioonal nesting, just create the ncclComms[i]
|
||||
// entry if it hasn't been yet rather than untangling the
|
||||
// conditions that might have resulted in a split above.
|
||||
if (!ncclComms[i]) {
|
||||
#ifdef NCCL_HAS_COMM_NONBLOCKING
|
||||
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
|
||||
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
|
||||
#else
|
||||
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
|
||||
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Creates the NCCL streams
|
||||
streamVal.push_back(
|
||||
@ -1954,6 +1925,9 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
std::make_tuple(devicesKey),
|
||||
std::make_tuple(devices.size()));
|
||||
|
||||
// Hold the lock before modifying the cache.
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
// Record the communicators based on ncclUniqueId.
|
||||
ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);
|
||||
|
||||
@ -1997,20 +1971,9 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
it = devNCCLCommMap_.find(devicesKey);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
it != devNCCLCommMap_.end(), "Communicators not populated in cache!");
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
|
||||
uint64_t ret = 0;
|
||||
for (const auto& i : ncclIdToCommMap_) {
|
||||
for (const auto& j : i.second) {
|
||||
ret += j->getCommSplitCounter();
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Check validity of tensor
|
||||
|
@ -341,13 +341,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Configure ranks
|
||||
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
|
||||
#endif
|
||||
|
||||
// Optional "parent" backend and color to create communicators from
|
||||
// via `ncclCommSplit`
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
std::shared_ptr<ProcessGroupNCCL> split_from;
|
||||
int64_t split_color{0};
|
||||
#endif
|
||||
};
|
||||
|
||||
// If you wish to create multiple process groups, each with a potentially
|
||||
@ -516,10 +509,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// may indicate that there is some sort of collective desynchronization.
|
||||
uint64_t getSequenceNumberForGroup() override;
|
||||
|
||||
// Return the total number of splits the communicators held by this process
|
||||
// group have performed.
|
||||
uint64_t getCommSplitCounter() const;
|
||||
|
||||
void registerOnCompletionHook(
|
||||
std::function<void(std::shared_ptr<WorkInfo>)>&& hook) override;
|
||||
void waitForPendingWorks() override;
|
||||
|
@ -2290,9 +2290,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
|
||||
.def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
|
||||
.def(
|
||||
"comm_split_count",
|
||||
&::c10d::ProcessGroupNCCL::getCommSplitCounter)
|
||||
.def_property_readonly(
|
||||
"options", &::c10d::ProcessGroupNCCL::getOptions);
|
||||
|
||||
@ -2355,18 +2352,15 @@ Example::
|
||||
)")
|
||||
.def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
|
||||
#ifdef NCCL_HAS_COMM_CTA_CGA
|
||||
.def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
|
||||
#endif
|
||||
.def_readwrite(
|
||||
"is_high_priority_stream",
|
||||
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
|
||||
#ifdef NCCL_HAS_COMM_SPLIT
|
||||
.def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
|
||||
#else
|
||||
.def_readwrite(
|
||||
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
|
||||
.def_readwrite(
|
||||
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
|
||||
"is_high_priority_stream",
|
||||
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
|
||||
#endif
|
||||
;
|
||||
|
||||
#endif
|
||||
|
||||
|
@ -8,7 +8,6 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
@ -1315,29 +1314,7 @@ def _new_process_group_helper(
|
||||
pg_options.is_high_priority_stream = False
|
||||
pg_options._timeout = timeout
|
||||
|
||||
# If our new group includes all ranks, we can reduce
|
||||
# overhead by splitting the communicator (`nccCommSplit`).
|
||||
|
||||
# TODO: support this in the general case by calling
|
||||
# `nccCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks
|
||||
# not in the communicator.
|
||||
split_from = None
|
||||
if (
|
||||
is_initialized()
|
||||
and _world.default_pg._get_backend_name() == Backend.NCCL
|
||||
and len(global_ranks_in_group) == _world.default_pg.size()
|
||||
):
|
||||
# If possible, find a backend to split from by peeling
|
||||
# process group wrappers from the world's default pg.
|
||||
split_from = _world.default_pg._get_backend(_get_pg_default_device())
|
||||
while isinstance(split_from, _ProcessGroupWrapper):
|
||||
split_from = split_from.wrapped_pg
|
||||
|
||||
if split_from:
|
||||
pg_options.split_from = split_from
|
||||
pg_options.split_color = _process_group_color(global_ranks_in_group)
|
||||
backend_class = ProcessGroupNCCL(
|
||||
backend_prefix_store, group_rank, group_size, pg_options)
|
||||
backend_class = ProcessGroupNCCL(backend_prefix_store, group_rank, group_size, pg_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
elif backend_str == Backend.UCC and is_ucc_available():
|
||||
# TODO: once UCC plugin is fully deprecated, remove
|
||||
@ -3537,19 +3514,11 @@ def _create_process_group_wrapper(
|
||||
wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
|
||||
return wrapped_pg
|
||||
|
||||
# helper function for deterministically hashing a list of ranks
|
||||
def _hash_ranks(ranks: List[int]):
|
||||
return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
|
||||
|
||||
# Takes a list of ranks and computes an integer color
|
||||
def _process_group_color(ranks: List[int]) -> int:
|
||||
# Convert our hash to an int, but avoid negative numbers by shifting a bit.
|
||||
return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
|
||||
|
||||
def _process_group_name(ranks, use_hashed_name):
|
||||
global _world
|
||||
if use_hashed_name:
|
||||
pg_name = _hash_ranks(ranks)
|
||||
pg_name = hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
|
||||
while pg_name in _world.pg_names.values():
|
||||
pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
|
||||
else:
|
||||
|
Reference in New Issue
Block a user