[c10d] Prototype of remote_group_merge (#158287)

Tentative implementation of merge_remote_group per the proposal here:
https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158287
Approved by: https://github.com/d4l3k
ghstack dependencies: #157716
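For orientation, a rough sketch of how the new API composes, distilled from the test added below; the port and the two-rank sizes are illustrative, `rank` stands for this process's global rank, and `subgroup` for a group previously obtained from split_group:

    import os
    from datetime import timedelta

    import torch.distributed as dist

    # Both merging sides rendezvous on a shared store; one side hosts it.
    store = dist.TCPStore(
        host_name=os.environ["MASTER_ADDR"],
        port=29781,
        world_size=2,
        is_master=(rank == 0),
    )
    # Merge this process's subgroup with the remote one into a 2-rank group.
    merged_pg = subgroup.merge_remote_group(store, 2, timedelta(seconds=40), "merged_pg")
    assert merged_pg.size() == 2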
@@ -5,6 +5,7 @@ import unittest
+from datetime import timedelta

 import torch
 import torch.distributed as dist
 import torch.distributed._dist2 as dist2
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,

@@ -216,6 +217,39 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
         else:
             self.assertEqual(subgroup, None)

+    def test_remote_group_merge(self) -> None:
+        group = self.new_group()
+        subgroup_1 = group.split_group([0], timeout=timedelta(seconds=30))
+        subgroup_2 = group.split_group([1], timeout=timedelta(seconds=30))
+        if self.rank == 0:
+            assert subgroup_1 is not None
+            tcp_store = dist.TCPStore(
+                host_name=os.environ["MASTER_ADDR"],
+                port=29781,
+                world_size=2,
+                is_master=True,
+            )
+            merged_pg = subgroup_1.merge_remote_group(
+                tcp_store, 2, timedelta(seconds=40), "merged_pg"
+            )
+            self.assertEqual(merged_pg.size(), 2)
+            backend = merged_pg._get_backend(self.device)
+            self.assertEqual(backend.options._timeout, timedelta(seconds=40))
+        else:
+            assert subgroup_2 is not None
+            tcp_store = dist.TCPStore(
+                host_name=os.environ["MASTER_ADDR"],
+                port=29781,
+                world_size=2,
+                is_master=False,
+            )
+            merged_pg = subgroup_2.merge_remote_group(
+                tcp_store, 2, timedelta(seconds=40), "merged_pg"
+            )
+            self.assertEqual(merged_pg.size(), 2)
+            backend = merged_pg._get_backend(self.device)
+            self.assertEqual(backend.options._timeout, timedelta(seconds=40))
+

 class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
     device = torch.device("cpu")
@@ -357,6 +357,14 @@ class ProcessGroup:
         pg_options: Optional[Backend.Options] = None,
         group_desc: Optional[str] = None,
     ) -> Optional[ProcessGroup]: ...
+    def merge_remote_group(
+        self,
+        store: Store,
+        size: int,
+        timeout: timedelta,
+        group_name: Optional[str] = None,
+        group_desc: Optional[str] = None,
+    ) -> ProcessGroup: ...
    def abort(self) -> None: ...
    def set_timeout(self, timeout: timedelta) -> None: ...
    def shutdown(self) -> None: ...
@@ -388,14 +388,26 @@ class TORCH_API Backend : public torch::CustomClassHolder {
         " is missing implementation of enableCollectivesTiming.");
   }

-  virtual c10::intrusive_ptr<Backend> splitBackend(
+  virtual c10::intrusive_ptr<Backend> split(
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Options> opts) {
     TORCH_CHECK(
         false,
         "Backend ",
         getBackendName(),
-        " is missing implementation of splitBackend.");
+        " is missing implementation of split.");
   }

+  virtual c10::intrusive_ptr<Backend> merge(
+      const c10::intrusive_ptr<Store>& store,
+      const c10::intrusive_ptr<Options> opts,
+      const int& rank,
+      const int& size) {
+    TORCH_CHECK(
+        false,
+        "Backend ",
+        getBackendName(),
+        " is missing implementation of merge.");
+  }
+
   bool hasHooks() const {
@@ -5,7 +5,6 @@
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
 #include <fmt/ranges.h>
-#include <string_view>

 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
@@ -190,7 +189,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
     backendOpts->group_name = groupName;
     backendOpts->timeout =
         timeout.has_value() ? timeout.value() : backendOpts->timeout;
-    auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
+    auto splitBackend = parentBackend->split(sorted_ranks, backendOpts);
     if (splitBackend == nullptr) {
       continue;
     }
@@ -216,6 +215,47 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
   return newGroup;
 }

+c10::intrusive_ptr<ProcessGroup> ProcessGroup::mergeRemoteGroup(
+    const c10::intrusive_ptr<Store>& store,
+    const MergeOptions& opts,
+    const int& size) {
+  c10::intrusive_ptr<ProcessGroup> newGroup;
+  // We assume the rank number is within the range of int32_t, so it won't overflow.
+  int rank = static_cast<int>(store->add("mergeGroupRank", 1) - 1);
+  // TODO: Do we need to check that all groups have the same deviceTypeToBackendType_?
+  for (const auto& pair : deviceTypeToBackendType_) {
+    c10::DeviceType deviceType = pair.first;
+    BackendType backendType = pair.second;
+
+    auto parentBackend = getBackend(deviceType);
+    auto backendOpts = parentBackend->getBackendOptions();
+    std::string groupName = opts.group_name.has_value()
+        ? opts.group_name.value()
+        : c10::str(getGroupName(), ":merge");
+    backendOpts->group_name = groupName;
+    backendOpts->timeout = opts.timeout;
+    auto mergedBackend = parentBackend->merge(store, backendOpts, rank, size);
+
+    std::string groupDesc = opts.group_desc.has_value()
+        ? opts.group_desc.value()
+        : c10::str(getGroupDesc(), ":merge");
+    mergedBackend->setGroupDesc(groupDesc);
+
+    // Historically, we have been using one process group to map to all
+    // backends. But in our new design, we will have one process group per
+    // backend. This logic is mostly for backward compatibility.
+    if (!newGroup) {
+      newGroup = c10::make_intrusive<ProcessGroup>(store, rank, size);
+      newGroup->setDefaultBackend(backendType_);
+      newGroup->setGroupName(groupName);
+      newGroup->setGroupDesc(groupDesc);
+    }
+    newGroup->setBackend(deviceType, backendType, mergedBackend);
+  }
+
+  return newGroup;
+}
+
 } // namespace c10d

 namespace {
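Rank assignment in the merged group is worth calling out: rather than being negotiated out of band, each caller's rank comes from an atomic fetch-and-add on the shared store. The same operation is available on the Python Store API; a one-line illustration of what the merge path performs internally in C++ above:

    # add() returns the post-increment total, so the value before this
    # caller's increment becomes its rank in the merged group: 0, 1, 2, ...
    rank = store.add("mergeGroupRank", 1) - 1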
@@ -71,6 +71,21 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input();
 //
 class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  public:
+  struct TORCH_API MergeOptions : torch::CustomClassHolder {
+    explicit MergeOptions(
+        const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout,
+        const std::optional<std::string> group_name = std::nullopt,
+        const std::optional<std::string> group_desc = std::nullopt)
+        : timeout(timeout), group_name(group_name), group_desc(group_desc) {}
+    ~MergeOptions() override = default;
+    MergeOptions(const MergeOptions&) = delete;
+    MergeOptions& operator=(const MergeOptions&) = delete;
+
+    std::chrono::milliseconds timeout;
+    std::optional<std::string> group_name;
+    std::optional<std::string> group_desc;
+  };
+
   enum BackendType : uint8_t {
     UNDEFINED = 0,
     GLOO = 1,
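Given these defaults, only the store and size are strictly required at the call site. A minimal call, relying on the defaulted timeout and the name/desc fallbacks applied in mergeRemoteGroup above (illustrative):

    # timeout falls back to kProcessGroupDefaultTimeout; with group_name and
    # group_desc omitted, the merged group is named "<parent_name>:merge".
    merged_pg = subgroup.merge_remote_group(store, 2)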
@@ -967,6 +982,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
       const std::optional<std::string>& groupDesc);

+  // This merges the current group with remote group(s) that rendezvous via
+  // the provided store, returning a new process group of the given size.
+  virtual c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
+      const c10::intrusive_ptr<Store>& store,
+      const MergeOptions& opts,
+      const int& size);
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
@@ -697,7 +697,7 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
   return options_->global_ranks_in_group;
 }

-c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
+c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options> opts) {
   auto it = std::find(ranks.begin(), ranks.end(), rank_);

@@ -726,6 +726,18 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }

+c10::intrusive_ptr<Backend> ProcessGroupGloo::merge(
+    const c10::intrusive_ptr<Store>& store,
+    const c10::intrusive_ptr<Backend::Options> opts,
+    const int& rank,
+    const int& size) {
+  auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
+  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
+  auto pg = c10::make_intrusive<ProcessGroupGloo>(
+      store->clone(), rank, size, glooOpts);
+  return c10::static_intrusive_pointer_cast<Backend>(pg);
+}
+
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
   std::unique_lock<std::mutex> lock(workMutex_);
   pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
@@ -308,10 +308,16 @@ class TORCH_API ProcessGroupGloo : public Backend {
     return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
   }

-  c10::intrusive_ptr<Backend> splitBackend(
+  c10::intrusive_ptr<Backend> split(
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Backend::Options> opts) override;

+  c10::intrusive_ptr<Backend> merge(
+      const c10::intrusive_ptr<Store>& store,
+      const c10::intrusive_ptr<Backend::Options> opts,
+      const int& rank,
+      const int& size) override;
+
   const std::vector<uint64_t>& groupRanks() const;

   c10::intrusive_ptr<Work> broadcast(
@@ -1311,13 +1311,13 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
   enableTiming_.store(true);
 }

-c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
+c10::intrusive_ptr<Backend> ProcessGroupNCCL::split(
     const std::vector<int>& ranks,
     const c10::intrusive_ptr<Backend::Options> opts) {
   auto deviceIdx = guessDeviceId();
   TORCH_CHECK(
       deviceIdx >= 0,
-      "ProcessGroupNCCL::splitBackend: rank ",
+      "ProcessGroupNCCL::split: rank ",
       rank_,
       " has no device bound to this rank.");
   auto device = at::Device(at::DeviceType::CUDA, deviceIdx);

@@ -1350,6 +1350,18 @@ c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
   return c10::static_intrusive_pointer_cast<Backend>(pg);
 }

+c10::intrusive_ptr<Backend> ProcessGroupNCCL::merge(
+    const c10::intrusive_ptr<Store>& store,
+    const c10::intrusive_ptr<Backend::Options> opts,
+    const int& rank,
+    const int& size) {
+  auto ncclOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
+  TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options.");
+  auto pg = c10::make_intrusive<ProcessGroupNCCL>(
+      store->clone(), rank, size, ncclOpts);
+  return c10::static_intrusive_pointer_cast<Backend>(pg);
+}
+
 bool ProcessGroupNCCL::waitForFutureOrTimeout(
     std::future<bool>& fut,
     const std::chrono::milliseconds& timeOutMilSec,
@@ -975,10 +975,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   void enableCollectivesTiming() override;

-  c10::intrusive_ptr<Backend> splitBackend(
+  c10::intrusive_ptr<Backend> split(
       const std::vector<int>& ranks,
       const c10::intrusive_ptr<Backend::Options> opts) override;

+  c10::intrusive_ptr<Backend> merge(
+      const c10::intrusive_ptr<Store>& store,
+      const c10::intrusive_ptr<Backend::Options> opts,
+      const int& rank,
+      const int& size) override;
+
   // Helper function for iteratively aborting communicators in the provided map
   void abortCommsFromMap(
       std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
@@ -166,6 +166,19 @@ class PyProcessGroup : public ProcessGroup {
         group_desc);
   }

+  c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
+      const c10::intrusive_ptr<c10d::Store>& store,
+      const MergeOptions& opts,
+      const int& size) override {
+    PYBIND11_OVERRIDE(
+        c10::intrusive_ptr<ProcessGroup>, /* Return type */
+        ProcessGroup, /* Parent class */
+        mergeRemoteGroup, /* Name of function in C++ */
+        store,
+        opts,
+        size);
+  }
+
   c10::intrusive_ptr<Work> allgather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
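Because the trampoline above routes through PYBIND11_OVERRIDE, a Python subclass of ProcessGroup can intercept the merge before deferring to C++. A hypothetical sketch (LoggingProcessGroup is illustrative, not part of this change):

    import torch.distributed as dist

    class LoggingProcessGroup(dist.ProcessGroup):  # hypothetical subclass
        def merge_remote_group(
            self, store, size, timeout, group_name=None, group_desc=None
        ):
            # Observe the call, then defer to the parent implementation.
            print(f"merging into a group of size {size}")
            return super().merge_remote_group(
                store, size, timeout, group_name, group_desc
            )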
@@ -2071,6 +2071,26 @@ communication mechanism.
           py::arg("opts") = std::nullopt,
           py::arg("groupDesc") = std::nullopt,
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "merge_remote_group",
+          [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
+             const c10::intrusive_ptr<::c10d::Store>& store,
+             int size,
+             std::chrono::milliseconds timeout,
+             std::optional<std::string> groupName,
+             std::optional<std::string> groupDesc) {
+            ::c10d::ProcessGroup::MergeOptions opts;
+            opts.timeout = timeout;
+            opts.group_name = groupName;
+            opts.group_desc = groupDesc;
+            return self->mergeRemoteGroup(store, opts, size);
+          },
+          py::arg("store"),
+          py::arg("size"),
+          py::arg("timeout") = kProcessGroupDefaultTimeout,
+          py::arg("group_name") = std::nullopt,
+          py::arg("group_desc") = std::nullopt,
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "abort",
           &::c10d::ProcessGroup::abort,