diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
index 56f67035a5fb..a1360c8dd40f 100644
--- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -28,7 +28,7 @@ class NCCLTestBase {

   NCCLTestBase(NCCLTestBase&& other) noexcept = default;

-  std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
+  ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
     return pg_;
   }
@@ -39,7 +39,7 @@ class NCCLTestBase {
   void initialize(
       int rank,
       size_t size,
-      std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
+      std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from =
           std::nullopt) {
     store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
@@ -52,13 +52,13 @@ class NCCLTestBase {
       opts->split_color = ++color_;
     }
 #endif
-    pg_ = std::make_unique<::c10d::ProcessGroupNCCL>(
+    pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
         store_, rank, size, std::move(opts));
   }

 protected:
   std::string path_;
-  std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
+  ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_;
   std::chrono::milliseconds pgTimeout_;
   ::c10::intrusive_ptr<::c10d::Store> store_;
   int color_{1};
diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py
index 52ffd34e2a48..d5e925b4b2d0 100644
--- a/test/distributed/test_dist2.py
+++ b/test/distributed/test_dist2.py
@@ -201,6 +201,17 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
             out_range = out[i * 10 : (i + 1) * 10]
             self.assertEqual(out_range, torch.full_like(out_range, i + 1))

+    def test_group_split(self) -> None:
+        group = self.new_group()
+        subgroup = group.split_group([0], timeout=timedelta(seconds=30))
+        if self.rank == 0:
+            assert subgroup is not None
+            self.assertEqual(subgroup.size(), 1)
+            backend = subgroup._get_backend(self.device)
+            self.assertEqual(backend.options._timeout, timedelta(seconds=30))
+        else:
+            self.assertEqual(subgroup, None)
+

 class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
     device = torch.device("cpu")
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 2efe44c86b55..f57bcb3472cc 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -350,6 +350,13 @@ class ProcessGroup:
     ) -> None: ...
     def rank(self) -> int: ...
     def size(self) -> int: ...
+    def split_group(
+        self,
+        new_ranks: list[int],
+        timeout: Optional[timedelta] = None,
+        pg_options: Optional[Backend.Options] = None,
+        group_desc: Optional[str] = None,
+    ) -> Optional[ProcessGroup]: ...
     def abort(self) -> None: ...
     def set_timeout(self, timeout: timedelta) -> None: ...
     def shutdown(self) -> None: ...
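The stub above mirrors the pybind binding registered in `init.cpp` further down. As a quick orientation for how the new method is driven from Python, here is a minimal sketch (not part of the patch; it assumes a small multi-rank job launched with `torchrun`, picks Gloo for simplicity, and reaches for the private helper `torch.distributed.distributed_c10d._get_default_group()` purely for brevity):

    # Illustrative usage of the new ProcessGroup.split_group binding.
    from datetime import timedelta

    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # e.g. torchrun --nproc-per-node=4
    pg = dist.distributed_c10d._get_default_group()  # private helper, brevity only

    # Every rank of the parent group calls split_group; ranks outside the list
    # get None back (the NCCL backend additionally performs a no-color split).
    subgroup = pg.split_group([0, 1], timeout=timedelta(seconds=30))
    if dist.get_rank() in (0, 1):
        assert subgroup is not None and subgroup.size() == 2
    else:
        assert subgroup is None

    dist.destroy_process_group()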
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index acece3d8c718..0f1c5116803f 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
     // backend name
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
     const std::string backend;
+    std::string group_name;
   };

   explicit Backend(int rank, int size);
@@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
     TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
   }

+  // Subclasses must override this method to return the backend options
+  virtual c10::intrusive_ptr<Options> getBackendOptions() {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ", getBackendName(), " does not implement getBackendOptions"));
+  }
+
   virtual c10::intrusive_ptr<Work> broadcast(
       std::vector<at::Tensor>& /* tensors */,
       const BroadcastOptions& /* opts */ = BroadcastOptions()) {
@@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder {
         " is missing implementation of enableCollectivesTiming.");
   }

+  virtual c10::intrusive_ptr<Backend> splitBackend(
+      const std::vector<int>& ranks,
+      const c10::intrusive_ptr<Options> opts) {
+    TORCH_CHECK(
+        false,
+        "Backend ",
+        getBackendName(),
+        " is missing implementation of splitBackend.");
+  }
+
   bool hasHooks() const {
     return onCompletionHook_ != nullptr;
   }
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp
index 60bb0f2d879e..8074cc98a04f 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.cpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -573,6 +573,27 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
   return hash;
 }

+// NCCL represents "in group" with a non-negative int (the split color), per
+// the ncclCommSplit API. Hash the list of ranks into a value that fits in a
+// non-negative 32-bit int, so every member of the new group derives the same color.
+int genNcclSplitColor(const std::vector<int>& ranks) {
+  // Combine the hash values using a simple reducer (std::hash + fold)
+  std::size_t combined_hash = std::accumulate(
+      ranks.begin(),
+      ranks.end(),
+      std::size_t(0),
+      [](std::size_t acc, int rank) {
+        return acc ^
+            (std::hash<int>{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2));
+      });
+
+  // max positive value of int32_t
+  constexpr int32_t max_c_int = std::numeric_limits<int32_t>::max();
+  int color = static_cast<int>(
+      std::abs(static_cast<int64_t>(combined_hash)) % max_c_int);
+  return color;
+}
+
 // Default value: 30 minutes
 int nccl_nonblocking_timeout() {
   static int timeout = -2; // -2 means not initialized
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 5e61837c2353..fcd55b6a655e 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -231,6 +231,7 @@ static std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
 };

 TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
+TORCH_API int genNcclSplitColor(const std::vector<int>& ranks);
 TORCH_API std::string getNcclVersion();
 TORCH_API std::tuple<int, int, int> getNcclVersionTuple();
 TORCH_API int getNcclVersionNumber();
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index 83418d17acdc..197fd9014b3a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -4,6 +4,7 @@
 #include <…>
 #include <…>
+#include <…>
 #include <…>
 #include <…>

@@ -158,6 +159,63 @@ void ProcessGroup::release_resources() {
   backendTypeToBackend_.clear();
 }

+c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
+    const std::vector<int>& ranks,
+    const std::optional<std::chrono::milliseconds> timeout,
+    const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+    const std::optional<std::string>& desc) {
+  TORCH_CHECK(
+      ranks.size() > 0,
+      "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
+  TORCH_CHECK(
+      ranks.size() < static_cast<size_t>(size_),
+      "the split group's size should be less than the world_size set by init_process_group");
+  std::set<int> ranks_set(ranks.begin(), ranks.end());
+  TORCH_CHECK(
+      ranks_set.size() == ranks.size(),
+      "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group.");
+  std::vector<int> sorted_ranks = ranks;
+  std::sort(sorted_ranks.begin(), sorted_ranks.end());
+  c10::intrusive_ptr<ProcessGroup> newGroup;
+  // TODO: Figure out a better way for split group name.
+  std::string groupName =
+      c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
+  for (const auto& pair : deviceTypeToBackendType_) {
+    c10::DeviceType deviceType = pair.first;
+    BackendType backendType = pair.second;
+
+    auto parentBackend = getBackend(deviceType);
+    auto backendOpts =
+        opts.has_value() ? opts.value() : parentBackend->getBackendOptions();
+    backendOpts->group_name = groupName;
+    backendOpts->timeout =
+        timeout.has_value() ? timeout.value() : backendOpts->timeout;
+    auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
+    if (splitBackend == nullptr) {
+      continue;
+    }
+
+    // TODO: Figure out a better way for split group desc.
+    // TODO: We can add a new field in Backend::Options to specify the group
+    // desc
+    std::string groupDesc = desc.has_value()
+        ? desc.value()
+        : c10::str(getGroupDesc(), ":split:", incrementSplitCount());
+    splitBackend->setGroupDesc(groupDesc);
+
+    if (!newGroup) {
+      newGroup = c10::make_intrusive<ProcessGroup>(
+          store_->clone(), splitBackend->getRank(), splitBackend->getSize());
+      newGroup->setDefaultBackend(backendType_);
+      newGroup->setGroupName(groupName);
+      newGroup->setGroupDesc(groupDesc);
+    }
+    newGroup->setBackend(deviceType, backendType, splitBackend);
+  }
+
+  return newGroup;
+}
+
 } // namespace c10d

 namespace {
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index da4bf65f4f39..5939f23e2972 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     }
   }

+  int64_t incrementSplitCount() {
+    return splitCounter_++;
+  }
+
   virtual void startCoalescing(c10::DeviceType deviceType) {
     // only nccl has implemented startCoalescing so only execute for nccl
     // backends
@@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     bound_device_id_ = device;
   }

+  // Creates a new subgroup from the specified ranks. All ranks of the parent
+  // group should call this; ranks that are not part of `ranks` get a nullptr
+  // back.
+  virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
+      const std::vector<int>& ranks,
+      const std::optional<std::chrono::milliseconds> timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::string>& groupDesc);
+
 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
@@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   BackendType backendType_;
   std::string pg_desc_;
+  int64_t splitCounter_{0};

   // Debug level setting. It is parsed once when ProcessGroup is constructed
   // and remains the same across use of this process group.
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 0df6073c5d2d..30301524bc57 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -697,6 +697,35 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
   return options_->global_ranks_in_group;
 }

+c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
+    const std::vector<int>& ranks,
+    const c10::intrusive_ptr<Backend::Options> opts) {
+  auto it = std::find(ranks.begin(), ranks.end(), rank_);
+  int groupRank;
+  if (it == ranks.end()) {
+    return nullptr;
+  } else {
+    groupRank = std::distance(ranks.begin(), it);
+  }
+
+  auto glooOpts = c10::dynamic_intrusive_pointer_cast<ProcessGroupGloo::Options>(opts);
+  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
+
+  // TODO: we need to get rid of globalRanksInGroup eventually.
+  std::vector<uint64_t> globalRanksInGroup;
+  for (auto rank : ranks) {
+    globalRanksInGroup.emplace_back(groupRanks()[rank]);
+  }
+  glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
+  auto store = std::dynamic_pointer_cast<ProcessGroupGloo::GlooStore>(store_);
+  TORCH_CHECK(
+      store != nullptr,
+      "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
+  auto pg = c10::make_intrusive<ProcessGroupGloo>(
+      store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
+  return c10::static_intrusive_pointer_cast<Backend>(pg);
+}
+
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
   std::unique_lock lock(workMutex_);
   pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
index e5f1ca740288..0ba2d416aedf 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
@@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }
 #endif

+    const c10::intrusive_ptr<::c10d::Store>& _getStore() const {
+      return store_;
+    }
+
   protected:
     c10::intrusive_ptr<::c10d::Store> store_;
   };
@@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }

     std::vector<uint64_t> global_ranks_in_group;
-    std::string group_name;
     std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
     int threads;
   };
@@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }
   }

+  c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
+    return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
+  }
+
+  c10::intrusive_ptr<Backend> splitBackend(
+      const std::vector<int>& ranks,
+      const c10::intrusive_ptr<Backend::Options> opts) override;
+
   const std::vector<uint64_t>& groupRanks() const;

   c10::intrusive_ptr<Work> broadcast(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index fda3879a8e8c..3dc7abbb7e54 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1311,6 +1311,45 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
   enableTiming_.store(true);
 }

+c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
+    const std::vector<int>& ranks,
+    const c10::intrusive_ptr<Backend::Options> opts) {
+  auto deviceIdx = guessDeviceId();
+  TORCH_CHECK(
+      deviceIdx >= 0,
+      "ProcessGroupNCCL::splitBackend: rank ",
+      rank_,
+      " has no device bound to it.");
+  auto device = at::Device(at::DeviceType::CUDA, deviceIdx);
+  auto it = std::find(ranks.begin(), ranks.end(), rank_);
+  int groupRank;
+  if (it == ranks.end()) {
+    // This rank is not in the new group, so a no-color split is performed
+    performNocolorSplit(device);
+    return nullptr;
+  } else {
+    groupRank = std::distance(ranks.begin(), it);
+  }
+
+  auto ncclOpts = c10::dynamic_intrusive_pointer_cast<ProcessGroupNCCL::Options>(opts);
+  TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options.");
+
+  // TODO: we need to get rid of globalRanksInGroup eventually.
+  std::vector<uint64_t> globalRanksInGroup;
+  for (auto rank : ranks) {
+    globalRanksInGroup.emplace_back(groupRanks()[rank]);
+  }
+  ncclOpts->split_from =
+      c10::intrusive_ptr<ProcessGroupNCCL>::unsafe_reclaim_from_nonowning(this);
+  ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup);
+  auto color = genNcclSplitColor(ranks);
+  ncclOpts->split_color = color;
+  auto pg = c10::make_intrusive<ProcessGroupNCCL>(
+      store_->clone(), groupRank, ranks.size(), ncclOpts);
+  pg->eagerConnectSingleDevice(device);
+  return c10::static_intrusive_pointer_cast<Backend>(pg);
+}
+
 bool ProcessGroupNCCL::waitForFutureOrTimeout(
     std::future<bool>& fut,
     const std::chrono::milliseconds& timeOutMilSec,
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index bf7ac47d8ed1..d7bb02e912c8 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -541,7 +541,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {

     // Optional "parent" backend and color to create communicators from
     // via `ncclCommSplit`
-    std::shared_ptr<ProcessGroupNCCL> split_from;
+    c10::intrusive_ptr<ProcessGroupNCCL> split_from;
     // Color to use for `ncclCommSplit`, values:
     // * Non-negative value: in group;
     // * NCCL_SPLIT_NOCOLOR (-1): not in group;
@@ -562,7 +562,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     int split_color{-2};
 #endif
     std::vector<uint64_t> global_ranks_in_group;
-    std::string group_name;
   };

@@ -804,6 +803,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     return options_;
   }

+  c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
+    return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
+  }
+
   const std::string getBackendName() const override {
     return std::string(NCCL_BACKEND_NAME);
   }
@@ -972,6 +975,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   void enableCollectivesTiming() override;

+  c10::intrusive_ptr<Backend> splitBackend(
+      const std::vector<int>& ranks,
+      const c10::intrusive_ptr<Backend::Options> opts) override;
+
   // Helper function for iteratively aborting communicators in the provided map
   void abortCommsFromMap(
       std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp
index 3355d0feebfb..854ea596aba8 100644
--- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp
@@ -151,6 +151,21 @@ class PyProcessGroup : public ProcessGroup {
         group_desc);
   }

+  c10::intrusive_ptr<ProcessGroup> splitGroup(
+      const std::vector<int>& ranks,
+      const std::optional<std::chrono::milliseconds> timeout,
+      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
+      const std::optional<std::string>& group_desc) override {
+    PYBIND11_OVERRIDE(
+        c10::intrusive_ptr<ProcessGroup>, /* Return type */
+        ProcessGroup, /* Parent class */
+        splitGroup, /* Name of function in C++ */
+        ranks,
+        timeout,
+        opts,
+        group_desc);
+  }
+
   c10::intrusive_ptr<Work> allgather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 0121bd6fd94b..5dfc99a893c7 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2063,6 +2063,14 @@ communication mechanism.
       .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
       .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
       .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
+      .def(
+          "split_group",
+          &::c10d::ProcessGroup::splitGroup,
+          py::arg("ranks"),
+          py::arg("timeout") = std::nullopt,
+          py::arg("opts") = std::nullopt,
+          py::arg("groupDesc") = std::nullopt,
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "abort",
           &::c10d::ProcessGroup::abort,
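For reference, the color that `genNcclSplitColor` (added in `NCCLUtils.cpp` above) hands to `ncclCommSplit` is a plain deterministic fold over the rank list, which `ProcessGroup::splitGroup` sorts before it reaches the backend. The sketch below is a rough Python model of that fold, illustrative only: the exact value differs from the C++ helper because `std::hash<int>` is implementation-defined, but it shows the property the backend relies on, namely that every member of a subgroup derives the same non-negative 32-bit color.

    # Rough Python model of the hash fold in genNcclSplitColor (illustrative only).
    INT32_MAX = 2**31 - 1    # std::numeric_limits<int32_t>::max()
    UINT64_MASK = 2**64 - 1  # emulate std::size_t wraparound

    def gen_split_color(ranks: list[int]) -> int:
        acc = 0
        for r in ranks:
            # boost-style hash_combine fold, mirroring the std::accumulate lambda
            acc = (acc ^ (hash(r) + 0x9E3779B9 + (acc << 6) + (acc >> 2))) & UINT64_MASK
        signed = acc - 2**64 if acc >= 2**63 else acc  # reinterpret as int64_t
        return abs(signed) % INT32_MAX

    # Ranks that agree on the (sorted) member list agree on the color,
    # and the result is a valid non-negative NCCL split color.
    assert gen_split_color([0, 2, 5]) == gen_split_color([0, 2, 5])
    assert 0 <= gen_split_color([1, 3]) < INT32_MAX

Because `splitGroup` sorts the ranks before calling `splitBackend`, all members of a given subgroup hash an identical list and therefore land in the same `ncclCommSplit` color group.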