[c10d] Prototype of group_split for dist2 work (#157716)

This implements group_split as proposed in [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157716
Approved by: https://github.com/d4l3k
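A rough usage sketch of the new API (it mirrors the test added in this PR; `pg` is assumed to be a ProcessGroup handle obtained from the dist2 prototype):

from datetime import timedelta

# Every rank of the parent group calls split_group; ranks listed in the split
# get a new ProcessGroup back, every other rank receives None.
subgroup = pg.split_group([0, 1], timeout=timedelta(seconds=30))
if subgroup is not None:
    # This rank is rank 0 or 1 of the parent group.
    assert subgroup.size() == 2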
Author: fduwjj
Date: 2025-07-14 10:00:40 -07:00
Committed by: PyTorch MergeBot
Parent: 1e4d8b5a4a
Commit: 6b2bef10af
14 changed files with 246 additions and 7 deletions

View File

@ -28,7 +28,7 @@ class NCCLTestBase {
NCCLTestBase(NCCLTestBase&& other) noexcept = default;
std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() {
return pg_;
}
@ -39,7 +39,7 @@ class NCCLTestBase {
void initialize(
int rank,
size_t size,
std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from =
std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from =
std::nullopt) {
store_ = c10::make_intrusive<::c10d::FileStore>(path_, size);
@ -52,13 +52,13 @@ class NCCLTestBase {
opts->split_color = ++color_;
}
#endif
pg_ = std::make_unique<::c10d::ProcessGroupNCCL>(
pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
store_, rank, size, std::move(opts));
}
protected:
std::string path_;
std::shared_ptr<::c10d::ProcessGroupNCCL> pg_;
::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_;
std::chrono::milliseconds pgTimeout_;
::c10::intrusive_ptr<::c10d::Store> store_;
int color_{1};

View File

@ -201,6 +201,17 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
out_range = out[i * 10 : (i + 1) * 10]
self.assertEqual(out_range, torch.full_like(out_range, i + 1))
def test_group_split(self) -> None:
group = self.new_group()
subgroup = group.split_group([0], timeout=timedelta(seconds=30))
if self.rank == 0:
assert subgroup is not None
self.assertEqual(subgroup.size(), 1)
backend = subgroup._get_backend(self.device)
self.assertEqual(backend.options._timeout, timedelta(seconds=30))
else:
self.assertEqual(subgroup, None)
class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
device = torch.device("cpu")

View File

@ -350,6 +350,13 @@ class ProcessGroup:
) -> None: ...
def rank(self) -> int: ...
def size(self) -> int: ...
def split_group(
self,
new_ranks: list[int],
timeout: Optional[timedelta] = None,
pg_options: Optional[Backend.Options] = None,
group_desc: Optional[str] = None,
) -> Optional[ProcessGroup]: ...
def abort(self) -> None: ...
def set_timeout(self, timeout: timedelta) -> None: ...
def shutdown(self) -> None: ...

View File

@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// backend name
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string backend;
std::string group_name;
};
explicit Backend(int rank, int size);
@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
}
// Subclasses must override this method to return their backend-specific options
virtual c10::intrusive_ptr<Options> getBackendOptions() {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), " does not implement endCoalescing"));
}
virtual c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& /* tensors */,
const BroadcastOptions& /* opts */ = BroadcastOptions()) {
@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder {
" is missing implementation of enableCollectivesTiming.");
}
virtual c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Options> opts) {
TORCH_CHECK(
false,
"Backend ",
getBackendName(),
" is missing implementation of splitBackend.");
}
bool hasHooks() const {
return onCompletionHook_ != nullptr;
}

View File

@ -573,6 +573,27 @@ size_t hashTensors(const std::vector<at::Tensor>& tensors) {
return hash;
}
// NCCL requires a non-negative int to represent in-group ranks (per its API).
// We take a list of ranks, generate a hash value from the list, and clamp the
// result to the range of a non-negative 32-bit int.
int genNcclSplitColor(const std::vector<int>& ranks) {
// Combine the hash values using a simple reducer (std::hash + fold)
std::size_t combined_hash = std::accumulate(
ranks.begin(),
ranks.end(),
std::size_t(0),
[](std::size_t acc, int rank) {
return acc ^
(std::hash<int>{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2));
});
// max positive value of int32_t
constexpr int32_t max_c_int = std::numeric_limits<int32_t>::max();
int color = static_cast<int>(
std::abs(static_cast<int64_t>(combined_hash)) % max_c_int);
return color;
}
// Default value: 30 minutes
int nccl_nonblocking_timeout() {
static int timeout = -2; // -2 means not initialized

View File

@ -231,6 +231,7 @@ static std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
};
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
TORCH_API int genNcclSplitColor(const std::vector<int>& ranks);
TORCH_API std::string getNcclVersion();
TORCH_API std::tuple<int, int, int> getNcclVersionTuple();
TORCH_API int getNcclVersionNumber();
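genNcclSplitColor (defined in the previous hunk and exported here) is what lets every member of a split agree on the ncclCommSplit color: ProcessGroup::splitGroup sorts the rank list before calling into the backend, and the final modulo keeps the value in the non-negative 32-bit range NCCL requires for in-group ranks. A rough Python re-creation of the fold (hypothetical helper; values will not match the C++ ones because std::hash is implementation-defined, but it shows the determinism and range properties):

def split_color_sketch(ranks):
    # sorted() stands in for the sort that splitGroup performs before computing a color
    acc = 0
    for rank in sorted(ranks):
        acc = (acc ^ (hash(rank) + 0x9E3779B9 + (acc << 6) + (acc >> 2))) % 2**64
    return acc % (2**31 - 1)  # clamp to a non-negative 32-bit int

# Same rank set in any order yields the same color on every process.
assert split_color_sketch([2, 0]) == split_color_sketch([0, 2])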

View File

@ -4,6 +4,7 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <string_view>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
@ -158,6 +159,63 @@ void ProcessGroup::release_resources() {
backendTypeToBackend_.clear();
}
c10::intrusive_ptr<ProcessGroup> ProcessGroup::splitGroup(
const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& desc) {
TORCH_CHECK(
ranks.size() > 0,
"Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group.");
TORCH_CHECK(
ranks.size() < static_cast<size_t>(size_),
"the split group's size should be less than the world_size set by init_process_group");
std::set<int> ranks_set(ranks.begin(), ranks.end());
TORCH_CHECK(
ranks_set.size() == ranks.size(),
"Split ranks should not have duplicates. Please provide a list of unique ranks to split the group.");
std::vector<int> sorted_ranks = ranks;
std::sort(sorted_ranks.begin(), sorted_ranks.end());
c10::intrusive_ptr<ProcessGroup> newGroup;
// TODO: Figure out a better way for split group name.
std::string groupName =
c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks));
for (const auto& pair : deviceTypeToBackendType_) {
c10::DeviceType deviceType = pair.first;
BackendType backendType = pair.second;
auto parentBackend = getBackend(deviceType);
auto backendOpts =
opts.has_value() ? opts.value() : parentBackend->getBackendOptions();
backendOpts->group_name = groupName;
backendOpts->timeout =
timeout.has_value() ? timeout.value() : backendOpts->timeout;
auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts);
if (splitBackend == nullptr) {
continue;
}
// TODO: Figure out a better way for split group desc.
// TODO: We can add a new field in Backend::Options to specify the group
// desc
std::string groupDesc = desc.has_value()
? desc.value()
: c10::str(getGroupDesc(), ":split:", incrementSplitCount());
splitBackend->setGroupDesc(groupDesc);
if (!newGroup) {
newGroup = c10::make_intrusive<ProcessGroup>(
store_->clone(), splitBackend->getRank(), splitBackend->getSize());
newGroup->setDefaultBackend(backendType_);
newGroup->setGroupName(groupName);
newGroup->setGroupDesc(groupDesc);
}
newGroup->setBackend(deviceType, backendType, splitBackend);
}
return newGroup;
}
} // namespace c10d
namespace {

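As a concrete illustration of the naming scheme above (hypothetical names): splitting ranks [0, 2] off a parent group named "default_pg" produces the group name "default_pg:split:[0, 2]" (fmt prints the sorted rank vector), and, when no desc is supplied, a description such as "<parent desc>:split:0", where the trailing counter comes from incrementSplitCount().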
View File

@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
}
}
int64_t incrementSplitCount() {
return splitCounter_++;
}
virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl
// backends
@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
bound_device_id_ = device;
}
// This creates a new subgroup using the specified ranks.
// All ranks of the parent group are expected to call this; ranks that are not
// included in `ranks` receive a nullptr.
virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& groupDesc);
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
BackendType backendType_;
std::string pg_desc_;
int64_t splitCounter_{0};
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.

View File

@ -697,6 +697,35 @@ const std::vector<uint64_t>& ProcessGroupGloo::groupRanks() const {
return options_->global_ranks_in_group;
}
c10::intrusive_ptr<Backend> ProcessGroupGloo::splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) {
auto it = std::find(ranks.begin(), ranks.end(), rank_);
int groupRank;
if (it == ranks.end()) {
return nullptr;
} else {
groupRank = std::distance(ranks.begin(), it);
}
auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
TORCH_CHECK(glooOpts != nullptr, "opts is not a ProcessGroupGloo::Options.");
// TODO: we need to get rid of globalRanksInGroup eventually.
std::vector<uint64_t> globalRanksInGroup;
for (auto rank : ranks) {
globalRanksInGroup.emplace_back(groupRanks()[rank]);
}
glooOpts->global_ranks_in_group = std::move(globalRanksInGroup);
auto store = std::dynamic_pointer_cast<GlooStore>(store_);
TORCH_CHECK(
store != nullptr,
"store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore.");
auto pg = c10::make_intrusive<ProcessGroupGloo>(
store->_getStore()->clone(), groupRank, ranks.size(), glooOpts);
return c10::static_intrusive_pointer_cast<Backend>(pg);
}
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
std::unique_lock<std::mutex> lock(workMutex_);
pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);

View File

@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
#endif
const c10::intrusive_ptr<::c10d::Store>& _getStore() const {
return store_;
}
protected:
c10::intrusive_ptr<::c10d::Store> store_;
};
@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
std::vector<uint64_t> global_ranks_in_group;
std::string group_name;
std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
int threads;
};
@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
}
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
}
c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) override;
const std::vector<uint64_t>& groupRanks() const;
c10::intrusive_ptr<Work> broadcast(

View File

@ -1311,6 +1311,45 @@ void ProcessGroupNCCL::enableCollectivesTiming() {
enableTiming_.store(true);
}
c10::intrusive_ptr<Backend> ProcessGroupNCCL::splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) {
auto deviceIdx = guessDeviceId();
TORCH_CHECK(
deviceIdx >= 0,
"ProcessGroupNCCL::splitBackend: rank ",
rank_,
" has no device is bound to this rank.");
auto device = at::Device(at::DeviceType::CUDA, deviceIdx);
auto it = std::find(ranks.begin(), ranks.end(), rank_);
int groupRank;
if (it == ranks.end()) {
// This rank is not in the new group, so a no-color split must be performed
performNocolorSplit(device);
return nullptr;
} else {
groupRank = std::distance(ranks.begin(), it);
}
auto ncclOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
TORCH_CHECK(ncclOpts != nullptr, "opts is not a ProcessGroupNCCL::Options.");
// TODO: we need to get rid of globalRanksInGroup eventually.
std::vector<uint64_t> globalRanksInGroup;
for (auto rank : ranks) {
globalRanksInGroup.emplace_back(groupRanks()[rank]);
}
ncclOpts->split_from =
c10::intrusive_ptr<ProcessGroupNCCL>::unsafe_reclaim_from_nonowning(this);
ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup);
auto color = genNcclSplitColor(ranks);
ncclOpts->split_color = color;
auto pg = c10::make_intrusive<ProcessGroupNCCL>(
store_->clone(), groupRank, ranks.size(), ncclOpts);
pg->eagerConnectSingleDevice(device);
return c10::static_intrusive_pointer_cast<Backend>(pg);
}
bool ProcessGroupNCCL::waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,

View File

@ -541,7 +541,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Optional "parent" backend and color to create communicators from
// via `ncclCommSplit`
std::shared_ptr<ProcessGroupNCCL> split_from;
c10::intrusive_ptr<ProcessGroupNCCL> split_from;
// Color to use for `ncclCommSplit`, values:
// * Non-negative value: in group;
// * NCCL_SPLIT_NOCOLOR (-1): not in group;
@ -562,7 +562,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int split_color{-2};
#endif
std::vector<uint64_t> global_ranks_in_group;
std::string group_name;
};
// Helper class related to TORCH_NCCL_DESYNC_DEBUG
@ -804,6 +803,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
return options_;
}
c10::intrusive_ptr<Backend::Options> getBackendOptions() override {
return c10::static_intrusive_pointer_cast<Backend::Options>(options_);
}
const std::string getBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
}
@ -972,6 +975,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void enableCollectivesTiming() override;
c10::intrusive_ptr<Backend> splitBackend(
const std::vector<int>& ranks,
const c10::intrusive_ptr<Backend::Options> opts) override;
// Helper function for iteratively aborting communicators in the provided map
void abortCommsFromMap(
std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,

View File

@ -151,6 +151,21 @@ class PyProcessGroup : public ProcessGroup {
group_desc);
}
c10::intrusive_ptr<ProcessGroup> splitGroup(
const std::vector<int>& ranks,
const std::optional<std::chrono::milliseconds> timeout,
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& group_desc) override {
PYBIND11_OVERRIDE(
c10::intrusive_ptr<ProcessGroup>, /* Return type */
ProcessGroup, /* Parent class */
splitGroup, /* Name of function in C++ */
ranks,
timeout,
opts,
group_desc);
}
c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,

View File

@ -2063,6 +2063,14 @@ communication mechanism.
.def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)")
.def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)")
.def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)")
.def(
"split_group",
&::c10d::ProcessGroup::splitGroup,
py::arg("ranks"),
py::arg("timeout") = std::nullopt,
py::arg("opts") = std::nullopt,
py::arg("groupDesc") = std::nullopt,
py::call_guard<py::gil_scoped_release>())
.def(
"abort",
&::c10d::ProcessGroup::abort,
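On the Python side the optional parameters become keyword arguments. A hypothetical call that also passes backend options and a description (argument names follow the stub and binding above; `pg` is assumed to be an existing NCCL-backed ProcessGroup):

from datetime import timedelta
import torch.distributed as dist

opts = dist.ProcessGroupNCCL.Options()
opts.is_high_priority_stream = True  # any backend option works here; shown for illustration
subgroup = pg.split_group(
    [0, 1],
    timeout=timedelta(seconds=60),
    pg_options=opts,
    group_desc="pipeline_stage_0",
)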