[c10d] Prototype of remote_group_merge (#158287)

Tentative implementation of merge_remote_group per the proposal here: [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89)
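In short: each rank starts from its own (sub)group, all participants rendezvous on a shared store, and the store hands out ranks in the merged group. A minimal two-process sketch of the intended flow, modeled on the test below (the `dist2.new_group` arguments, the store port, and the rank-selection logic are illustrative assumptions, not part of this PR):

```python
import os
from datetime import timedelta

import torch.distributed as dist
import torch.distributed._dist2 as dist2

# Assumed launch: two processes with MASTER_ADDR/MASTER_PORT, RANK, and
# WORLD_SIZE set (e.g. via torchrun --nproc-per-node=2).
group = dist2.new_group(
    backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None
)  # assumed arguments; see torch.distributed._dist2 for the real signature

# Split the world into two single-rank groups. Both ranks take part in both
# split calls; each keeps only the subgroup that contains itself.
sub0 = group.split_group([0], timeout=timedelta(seconds=30))
sub1 = group.split_group([1], timeout=timedelta(seconds=30))
subgroup = sub0 if group.rank() == 0 else sub1

# Rendezvous on a common TCPStore and merge the single-rank groups into one
# group of size 2.
store = dist.TCPStore(
    host_name=os.environ["MASTER_ADDR"],
    port=29781,
    world_size=2,
    is_master=group.rank() == 0,
)
merged = subgroup.merge_remote_group(store, 2, timedelta(seconds=40), "merged_pg")
assert merged.size() == 2
```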

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158287
Approved by: https://github.com/d4l3k
ghstack dependencies: #157716
Authored by fduwjj on 2025-07-16 07:13:57 -07:00
Committed by PyTorch MergeBot
parent 944a140e90
commit f58a680d09
11 changed files with 194 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ import unittest
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.distributed._dist2 as dist2
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,

@@ -216,6 +217,39 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
        else:
            self.assertEqual(subgroup, None)

    def test_remote_group_merge(self) -> None:
        group = self.new_group()
        subgroup_1 = group.split_group([0], timeout=timedelta(seconds=30))
        subgroup_2 = group.split_group([1], timeout=timedelta(seconds=30))
        if self.rank == 0:
            assert subgroup_1 is not None
            tcp_store = dist.TCPStore(
                host_name=os.environ["MASTER_ADDR"],
                port=29781,
                world_size=2,
                is_master=True,
            )
            merged_pg = subgroup_1.merge_remote_group(
                tcp_store, 2, timedelta(seconds=40), "merged_pg"
            )
            self.assertEqual(merged_pg.size(), 2)
            backend = merged_pg._get_backend(self.device)
            self.assertEqual(backend.options._timeout, timedelta(seconds=40))
        else:
            assert subgroup_2 is not None
            tcp_store = dist.TCPStore(
                host_name=os.environ["MASTER_ADDR"],
                port=29781,
                world_size=2,
                is_master=False,
            )
            merged_pg = subgroup_2.merge_remote_group(
                tcp_store, 2, timedelta(seconds=40), "merged_pg"
            )
            self.assertEqual(merged_pg.size(), 2)
            backend = merged_pg._get_backend(self.device)
            self.assertEqual(backend.options._timeout, timedelta(seconds=40))


class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
    device = torch.device("cpu")

View File

@@ -357,6 +357,14 @@ class ProcessGroup:
        pg_options: Optional[Backend.Options] = None,
        group_desc: Optional[str] = None,
    ) -> Optional[ProcessGroup]: ...
    def merge_remote_group(
        self,
        store: Store,
        size: int,
        timeout: timedelta,
        group_name: Optional[str] = None,
        group_desc: Optional[str] = None,
    ) -> ProcessGroup: ...
    def abort(self) -> None: ...
    def set_timeout(self, timeout: timedelta) -> None: ...
    def shutdown(self) -> None: ...

View File

@@ -388,14 +388,26 @@ class TORCH_API Backend : public torch::CustomClassHolder {
        " is missing implementation of enableCollectivesTiming.");
  }

-  virtual c10::intrusive_ptr<Backend> splitBackend(
+  virtual c10::intrusive_ptr<Backend> split(
      const std::vector<int>& ranks,
      const c10::intrusive_ptr<Options> opts) {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
-        " is missing implementation of splitBackend.");
+        " is missing implementation of split.");
  }

  virtual c10::intrusive_ptr<Backend> merge(
      const c10::intrusive_ptr<Store>& store,
      const c10::intrusive_ptr<Options> opts,
      const int& rank,
      const int& size) {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
        " is missing implementation of merge.");
  }

  bool hasHooks() const {

View File

@@ -5,7 +5,6 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
-#include <string_view>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>

@@ -190,7 +189,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
    backendOpts->group_name = groupName;
    backendOpts->timeout =
        timeout.has_value() ? timeout.value() : backendOpts->timeout;
-    auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
+    auto splitBackend = parentBackend->split(sorted_ranks, backendOpts);
    if (splitBackend == nullptr) {
      continue;
    }
@@ -216,6 +215,47 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
  return newGroup;
}

c10::intrusive_ptr<ProcessGroup> ProcessGroup::mergeRemoteGroup(
    const c10::intrusive_ptr<Store>& store,
    const MergeOptions& opts,
    const int& size) {
  c10::intrusive_ptr<ProcessGroup> newGroup;
  // We assume the rank number is within the range of int32_t, so it won't
  // overflow.
  int rank = static_cast<int>(store->add("mergeGroupRank", 1) - 1);
  // TODO: Do we need to check that all groups have the same
  // deviceTypeToBackendType_?
  for (const auto& pair : deviceTypeToBackendType_) {
    c10::DeviceType deviceType = pair.first;
    BackendType backendType = pair.second;
    auto parentBackend = getBackend(deviceType);
    auto backendOpts = parentBackend->getBackendOptions();
    std::string groupName = opts.group_name.has_value()
        ? opts.group_name.value()
        : c10::str(getGroupName(), ":merge");
    backendOpts->group_name = groupName;
    backendOpts->timeout = opts.timeout;
    auto mergedBackend = parentBackend->merge(store, backendOpts, rank, size);
    std::string groupDesc = opts.group_desc.has_value()
        ? opts.group_desc.value()
        : c10::str(getGroupDesc(), ":merge");
    mergedBackend->setGroupDesc(groupDesc);
    // Historically, we have been using one process_group to map to all
    // backends, but in our new design we will have one process_group per
    // backend. This logic is mostly for backward compatibility.
    if (!newGroup) {
      newGroup = c10::make_intrusive<ProcessGroup>(store, rank, size);
      newGroup->setDefaultBackend(backendType_);
      newGroup->setGroupName(groupName);
      newGroup->setGroupDesc(groupDesc);
    }
    newGroup->setBackend(deviceType, backendType, mergedBackend);
  }
  return newGroup;
}

} // namespace c10d

namespace {
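One detail worth calling out: `mergeRemoteGroup` assigns ranks in the merged group by arrival order at the store, via an atomic counter, not by the ranks held in the parent groups. A small Python sketch of the same counter pattern through the public `Store` API (the endpoint is an illustrative assumption; the key matches the one used above):

```python
from datetime import timedelta

import torch.distributed as dist

# add() atomically increments the key and returns the updated value, so the
# first process to reach the store sees 1 (rank 0), the second sees 2
# (rank 1), and so on: first come, first ranked.
store = dist.TCPStore(
    "127.0.0.1", 29781, world_size=2, is_master=True, timeout=timedelta(seconds=30)
)
rank = store.add("mergeGroupRank", 1) - 1
```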

View File

@@ -71,6 +71,21 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input();
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  struct TORCH_API MergeOptions : torch::CustomClassHolder {
    explicit MergeOptions(
        const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout,
        const std::optional<std::string> group_name = std::nullopt,
        const std::optional<std::string> group_desc = std::nullopt)
        : timeout(timeout), group_name(group_name), group_desc(group_desc) {}
    ~MergeOptions() override = default;
    MergeOptions(const MergeOptions&) = delete;
    MergeOptions& operator=(const MergeOptions&) = delete;

    std::chrono::milliseconds timeout;
    std::optional<std::string> group_name;
    std::optional<std::string> group_desc;
  };

  enum BackendType : uint8_t {
    UNDEFINED = 0,
    GLOO = 1,
@@ -967,6 +982,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
      const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
      const std::optional<std::string>& groupDesc);

  // This merges the calling group with remote groups that rendezvous on the
  // given store, returning a new group of the given size.
  virtual c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
      const c10::intrusive_ptr<Store>& store,
      const MergeOptions& opts,
      const int& size);

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.

View File

@@ -697,7 +697,7 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
  return options_->global_ranks_in_group;
}

-c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
+c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
    const std::vector<int>& ranks,
    const c10::intrusive_ptr<Backend::Options> opts) {
  auto it = std::find(ranks.begin(), ranks.end(), rank_);

@@ -726,6 +726,18 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}

c10::intrusive_ptr<Backend> ProcessGroupGloo::merge(
    const c10::intrusive_ptr<Store>& store,
    const c10::intrusive_ptr<Backend::Options> opts,
    const int& rank,
    const int& size) {
  auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
  auto pg = c10::make_intrusive<ProcessGroupGloo>(
      store->clone(), rank, size, glooOpts);
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  std::unique_lock<std::mutex> lock(workMutex_);
  pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);

View File

@@ -308,10 +308,16 @@ class TORCH_API ProcessGroupGloo : public Backend {
    return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
  }

-  c10::intrusive_ptr<Backend> splitBackend(
+  c10::intrusive_ptr<Backend> split(
      const std::vector<int>& ranks,
      const c10::intrusive_ptr<Backend::Options> opts) override;

  c10::intrusive_ptr<Backend> merge(
      const c10::intrusive_ptr<Store>& store,
      const c10::intrusive_ptr<Backend::Options> opts,
      const int& rank,
      const int& size) override;

  const std::vector<uint64_t>& groupRanks() const;

  c10::intrusive_ptr<Work> broadcast(

View File

@@ -1311,13 +1311,13 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
  enableTiming_.store(true);
}

-c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
+c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
    const std::vector<int>& ranks,
    const c10::intrusive_ptr<Backend::Options> opts) {
  auto deviceIdx = guessDeviceId();
  TORCH_CHECK(
      deviceIdx >= 0,
-      "ProcessGroupNCCL::splitBackend: rank ",
+      "ProcessGroupNCCL::split: rank ",
      rank_,
      " has no device bound to this rank.");
  auto device = at::Device(at::DeviceType::CUDA, deviceIdx);

@@ -1350,6 +1350,18 @@ c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}

c10::intrusive_ptr<Backend> ProcessGroupNCCL::merge(
    const c10::intrusive_ptr<Store>& store,
    const c10::intrusive_ptr<Backend::Options> opts,
    const int& rank,
    const int& size) {
  auto ncclOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
  TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options.");
  auto pg = c10::make_intrusive<ProcessGroupNCCL>(
      store->clone(), rank, size, ncclOpts);
  return c10::static_intrusive_pointer_cast<Backend>(pg);
}

bool ProcessGroupNCCL::waitForFutureOrTimeout(
    std::future<bool>& fut,
    const std::chrono::milliseconds& timeOutMilSec,

View File

@@ -975,10 +975,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  void enableCollectivesTiming() override;

-  c10::intrusive_ptr<Backend> splitBackend(
+  c10::intrusive_ptr<Backend> split(
      const std::vector<int>& ranks,
      const c10::intrusive_ptr<Backend::Options> opts) override;

  c10::intrusive_ptr<Backend> merge(
      const c10::intrusive_ptr<Store>& store,
      const c10::intrusive_ptr<Backend::Options> opts,
      const int& rank,
      const int& size) override;

  // Helper function for iteratively aborting communicators in the provided map
  void abortCommsFromMap(
      std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,

View File

@@ -166,6 +166,19 @@ class PyProcessGroup : public ProcessGroup {
        group_desc);
  }

  c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
      const c10::intrusive_ptr<c10d::Store>& store,
      const MergeOptions& opts,
      const int& size) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup>, /* Return type */
        ProcessGroup, /* Parent class */
        mergeRemoteGroup, /* Name of function in C++ */
        store,
        opts,
        size);
  }

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,

View File

@@ -2071,6 +2071,26 @@ communication mechanism.
          py::arg("opts") = std::nullopt,
          py::arg("groupDesc") = std::nullopt,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "merge_remote_group",
          [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
             const c10::intrusive_ptr<::c10d::Store>& store,
             int size,
             std::chrono::milliseconds timeout,
             std::optional<std::string> groupName,
             std::optional<std::string> groupDesc) {
            ::c10d::ProcessGroup::MergeOptions opts;
            opts.timeout = timeout;
            opts.group_name = groupName;
            opts.group_desc = groupDesc;
            return self->mergeRemoteGroup(store, opts, size);
          },
          py::arg("store"),
          py::arg("size"),
          py::arg("timeout") = kProcessGroupDefaultTimeout,
          py::arg("group_name") = std::nullopt,
          py::arg("group_desc") = std::nullopt,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "abort",
          &::c10d::ProcessGroup::abort,
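On the Python side the binding surfaces as `merge_remote_group(store, size, timeout=..., group_name=..., group_desc=...)`, with everything after `size` optional. A hedged call-site sketch (`pg` and `store` as set up earlier; the names are arbitrary):

```python
from datetime import timedelta

# Minimal form: timeout falls back to the c10d default, and group_name /
# group_desc default to "<parent name>:merge" / "<parent desc>:merge".
merged = pg.merge_remote_group(store, size=2)

# Fully specified form:
merged = pg.merge_remote_group(
    store,
    size=2,
    timeout=timedelta(seconds=40),
    group_name="merged_pg",
    group_desc="example merged group",
)
```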