[c10d]Prototype of remote_group_merge (#158287)

Tentative implementation of merge_remote_group per the proposal here: [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158287
Approved by: https://github.com/d4l3k
ghstack dependencies: #157716
This commit is contained in:
fduwjj
2025-07-16 07:13:57 -07:00
committed by PyTorch MergeBot
parent 944a140e90
commit f58a680d09
11 changed files with 194 additions and 9 deletions

View File

@ -71,6 +71,21 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input();
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
public:
struct TORCH_API MergeOptions : torch::CustomClassHolder {
explicit MergeOptions(
const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout,
const std::optional<std::string> group_name = std::nullopt,
const std::optional<std::string> group_desc = std::nullopt)
: timeout(timeout), group_name(group_name), group_desc(group_desc) {}
~MergeOptions() override = default;
MergeOptions(const MergeOptions&) = delete;
MergeOptions& operator=(const MergeOptions&) = delete;
std::chrono::milliseconds timeout;
std::optional<std::string> group_name;
std::optional<std::string> group_desc;
};
enum BackendType : uint8_t {
UNDEFINED = 0,
GLOO = 1,
@ -967,6 +982,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::optional<c10::intrusive_ptr<Backend::Options>> opts,
const std::optional<std::string>& groupDesc);
// This creates a new subgroup using the specified ranks.
// The current rank must be included in the list of new_ranks.
virtual c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
const c10::intrusive_ptr<Store>& store,
const MergeOptions& opts,
const int& size);
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.