[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)

We introduced the dispatchable backend for a ProcessGroup and collectives in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup that removes the Options argument from ProcessGroup; users are asked either to set the timeout or the backend later on, or to create the backend directly after creating a PG.
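For illustration, a minimal sketch of the new construction flow (a hypothetical single-rank setup with a file-backed store; _set_default_backend is the private helper added in this PR, and this is not a public API):

import tempfile
from torch.distributed import FileStore, PrefixStore, ProcessGroup

# Throwaway file-backed store for a world of size 1 (illustrative only).
with tempfile.NamedTemporaryFile(delete=False) as f:
    store = FileStore(f.name, 1)
prefix_store = PrefixStore("default_pg/", store)

# The ProcessGroup constructor no longer takes an Options object;
# the default backend is set afterwards instead.
pg = ProcessGroup(prefix_store, 0, 1)
pg._set_default_backend(ProcessGroup.BackendType.GLOO)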

Also, PGNCCL was using the Options class from ProcessGroup, but it should really use the Options from the Backend class. So this PR aligns the types and names with what we are doing on the cpp side. I don't change the signatures of the public APIs, so they still use args named "pg_options".
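The public entry point therefore still looks like this; only the expected type of pg_options now derives from Backend.Options (a hedged sketch assuming a CUDA build with NCCL and a placeholder TCP rendezvous address):

import torch.distributed as dist
from torch.distributed import ProcessGroupNCCL

# ProcessGroupNCCL.Options now subclasses Backend.Options instead of
# ProcessGroup.Options, but the keyword argument keeps its old name.
nccl_options = ProcessGroupNCCL.Options()
nccl_options.is_high_priority_stream = False

dist.init_process_group(
    backend="nccl",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
    pg_options=nccl_options,
)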

The tests are updated to stay aligned with this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132931
Approved by: https://github.com/H-Huang
Authored by fduwjj on 2024-08-28 15:35:16 -07:00; committed by PyTorch MergeBot
parent 8b4c487581
commit 65864d0134
8 changed files with 113 additions and 117 deletions

View File

@ -1815,6 +1815,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
def test_init_process_group_for_all_backends(self):
for backend in dist.Backend.backend_list:
excepted_backend = backend
# skip if the backend is not available on the system
if backend == dist.Backend.UNDEFINED:
continue
@ -1830,6 +1831,11 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
elif backend == dist.Backend.UCC:
if not dist.is_ucc_available():
continue
# Multi-threaded PG is defined as a pure python class.
# Its pg.name() does not go through Pybind, so its backend name
# is still "threaded" instead of "custom".
elif backend != "threaded":
excepted_backend = "custom"
with tempfile.NamedTemporaryFile(delete=False) as f:
store = dist.FileStore(f.name, self.world_size)
@ -1842,7 +1848,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
pg = c10d._get_default_group()
self.assertEqual(pg.rank(), self.rank)
self.assertEqual(pg.size(), self.world_size)
self.assertEqual(pg.name(), str(backend))
self.assertEqual(pg.name(), str(excepted_backend))
dist.destroy_process_group()

View File

@ -232,7 +232,8 @@ class DeviceMeshTest(DTensorTestBase):
mesh_tensor = torch.arange(4).reshape(2, 2)
mesh = DeviceMesh(device_type, mesh_tensor)
self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")
# Fake pg only has BackendType::CUSTOM as its backend type.
self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")
class DeviceMeshTestNDim(DTensorTestBase):

View File

@ -296,15 +296,6 @@ class Backend:
def _set_default_timeout(self, timeout: timedelta) -> None: ...
class ProcessGroup:
class Options:
def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
@property
def backend(self) -> str: ...
@property
def _timeout(self) -> timedelta: ...
@_timeout.setter
def _timeout(self, val: timedelta) -> None: ...
class BackendType(Enum):
UNDEFINED = ...
GLOO = ...
@ -318,7 +309,6 @@ class ProcessGroup:
store: Store,
rank: int,
size: int,
options: Options,
) -> None: ...
def rank(self) -> int: ...
def size(self) -> int: ...
@ -508,6 +498,7 @@ class ProcessGroup:
@property
def _device_types(self) -> list[torch.device]: ...
def _get_backend(self, device: torch.device) -> Backend: ...
def _set_default_backend(self, backend_type: BackendType) -> None: ...
def _register_backend(
self,
device: torch.device,
@ -532,7 +523,7 @@ class ProcessGroup:
class ProcessGroupGloo(Backend):
class Device: ...
class Options(ProcessGroup.Options):
class Options(Backend.Options):
devices: list[ProcessGroupGloo.Device]
threads: int
@ -562,7 +553,7 @@ class ProcessGroupNCCL(Backend):
min_ctas: int
max_ctas: int
class Options(ProcessGroup.Options):
class Options(Backend.Options):
config: ProcessGroupNCCL.NCCLConfig
is_high_priority_stream: bool
split_from: ProcessGroupNCCL

View File

@ -14,22 +14,6 @@
namespace c10d {
static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
if (backend == "undefined") {
return ProcessGroup::BackendType::UNDEFINED;
} else if (backend == "gloo") {
return ProcessGroup::BackendType::GLOO;
} else if (backend == "nccl") {
return ProcessGroup::BackendType::NCCL;
} else if (backend == "ucc") {
return ProcessGroup::BackendType::UCC;
} else if (backend == "mpi") {
return ProcessGroup::BackendType::MPI;
} else {
return ProcessGroup::BackendType::CUSTOM;
}
}
std::string opTypeToString(OpType opType) {
switch (opType) {
case OpType::BROADCAST:
@ -119,13 +103,11 @@ c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
ProcessGroup::ProcessGroup(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options)
int size)
: store_(store),
rank_(rank),
size_(size),
options_(std::move(options)),
backendType_(strToBackendType(options_->backend)),
backendType_(BackendType::UNDEFINED),
dist_debug_level_(debug_level()) {
C10_LOG_API_USAGE_ONCE("c10d.process_group");
}

View File

@ -45,24 +45,6 @@ namespace c10d {
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
public:
// ProcessGroup Options is a base struct that defines the basic options
// when constructing a ProcessGroup. Each ProcessGroup subclass should
// extend this struct and define its options if it wants to provide more
// config options (beyond basic ones defined here) to end user.
struct TORCH_API Options : torch::CustomClassHolder {
explicit Options(
std::string backend,
std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
: timeout(timeout), backend(std::move(backend)) {}
~Options() override = default;
std::chrono::milliseconds timeout;
// backend name
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string backend;
};
enum BackendType : uint8_t {
UNDEFINED = 0,
GLOO = 1,
@ -72,6 +54,23 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
CUSTOM = 5,
};
static std::string backendTypeToString(BackendType type) {
switch (type) {
case BackendType::GLOO:
return "gloo";
case BackendType::NCCL:
return "nccl";
case BackendType::UCC:
return "ucc";
case BackendType::MPI:
return "mpi";
case BackendType::UNDEFINED:
return "undefined";
default:
return "custom";
}
};
// Not used, set for backwards compatibility and only used for TypeDef in
// Ops.cpp
explicit ProcessGroup(int rank, int size);
@ -79,8 +78,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
explicit ProcessGroup(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options);
int size);
~ProcessGroup() override;
int getRank() const {
@ -103,7 +101,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
}
virtual const std::string getBackendName() const {
return options_->backend;
return backendTypeToString(backendType_);
};
BackendType getBackendType() const {
@ -609,10 +607,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
opts.timeout.count());
}
c10::intrusive_ptr<Options> getOptions() {
return options_;
}
bool hasBackends() {
return !deviceTypeToBackendType_.empty();
}
@ -653,6 +647,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return backendTypeToBackend_.at(backendType_);
}
void setDefaultBackend(const BackendType& backendType) {
backendType_ = backendType;
}
c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);
c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
@ -725,9 +723,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int size_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const c10::intrusive_ptr<Options> options_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const BackendType backendType_;
BackendType backendType_;
std::string pg_desc_;
// Debug level setting. It is parsed once when ProcessGroup is constructed and

View File

@ -1814,8 +1814,7 @@ communication mechanism.
py::init<
const c10::intrusive_ptr<::c10d::Store>&,
int,
int,
c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
int>(),
py::call_guard<py::gil_scoped_release>())
.def("rank", &::c10d::ProcessGroup::getRank)
.def("size", &::c10d::ProcessGroup::getSize)
@ -1825,7 +1824,6 @@ communication mechanism.
"_backend_id",
&::c10d::ProcessGroup::getBackendID,
py::arg("backend_type"))
.def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
.def(
"broadcast",
&::c10d::ProcessGroup::broadcast,
@ -2135,6 +2133,14 @@ communication mechanism.
},
py::arg("device"),
py::call_guard<py::gil_scoped_release>())
.def(
"_set_default_backend",
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
const ::c10d::ProcessGroup::BackendType& backendType) {
return self->setDefaultBackend(backendType);
},
py::arg("backend_type"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_on_completion_hook",
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
@ -2237,27 +2243,6 @@ Arguments:
.value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
.export_values();
// base ProcessGroup::Options binding
auto processGroupOptions =
intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
processGroup,
"Options",
R"(
Base class for all processes group options implementations, such as the nccl
options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
)")
.def(
py::init([](const std::string& backend,
const std::chrono::milliseconds& timeout) {
return c10::make_intrusive<::c10d::ProcessGroup::Options>(
backend, timeout);
}),
py::arg("backend"),
py::arg("timeout") = kProcessGroupDefaultTimeout,
py::call_guard<py::gil_scoped_release>())
.def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
.def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
// TODO: The collection definitions handles direct instantiation of
// ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported
// and should be removed once all tests are transitioned
@ -2556,6 +2541,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
&::c10d::Backend::endCoalescing,
py::call_guard<py::gil_scoped_release>());
// base Backend::Options binding
// TODO: Maybe we can consider how to merge this with
// `DistributedBackendOptions`.
auto backendOptions =
intrusive_ptr_class_<::c10d::Backend::Options>(
backend,
"Options",
R"(
Base class for all backend options implementations, such as the nccl
options :class:`~torch.distributed.ProcessGroupNCCL.Options`.
)")
.def(
py::init([](const std::string& backend,
const std::chrono::milliseconds& timeout) {
return c10::make_intrusive<::c10d::Backend::Options>(
backend, timeout);
}),
py::arg("backend"),
py::arg("timeout") = kProcessGroupDefaultTimeout,
py::call_guard<py::gil_scoped_release>())
.def_readonly("backend", &::c10d::Backend::Options::backend)
.def_readwrite("_timeout", &::c10d::Backend::Options::timeout);
#ifdef USE_C10D_GLOO
static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
@ -2567,7 +2575,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
processGroupGloo, "_Options", processGroupOptions)
processGroupGloo, "_Options", backendOptions)
.def(py::init<>())
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
@ -2794,7 +2802,7 @@ for details.
intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
processGroupNCCL,
"Options",
processGroupOptions,
backendOptions,
R"(
ProcessGroup options for the NCCL backend

View File

@ -36,6 +36,7 @@ if not is_available():
else:
from torch._C._distributed_c10d import Backend as C10dBackend
from torch.distributed.distributed_c10d import (
_find_pg_by_ranks_and_tag,
_get_default_group,
@ -66,7 +67,7 @@ else:
self.mesh_stack: List[DeviceMesh] = []
self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
self.mesh_dim_group_options: Dict[
int, Tuple[str, Optional[ProcessGroup.Options]]
int, Tuple[str, Optional[C10dBackend.Options]]
] = {}
self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
# Record flatten mesh name to its mesh dim index in root mesh.
@ -279,7 +280,7 @@ else:
self,
dim: int,
backend: str,
pg_options: Optional[ProcessGroup.Options] = None,
pg_options: Optional[C10dBackend.Options] = None,
) -> None:
self.mesh_dim_group_options[dim] = (backend, pg_options)

View File

@ -266,6 +266,7 @@ class Backend(str):
GLOO: ProcessGroup.BackendType.GLOO,
NCCL: ProcessGroup.BackendType.NCCL,
UCC: ProcessGroup.BackendType.UCC,
MPI: ProcessGroup.BackendType.MPI,
}
def __new__(cls, name: str):
@ -1531,7 +1532,7 @@ def init_process_group(
backend,
store,
group_name,
pg_options=pg_options,
backend_options=pg_options,
timeout=timeout,
device_id=device_id,
group_desc="default_pg",
@ -1628,7 +1629,7 @@ def _new_process_group_helper(
backend,
store,
group_name,
pg_options=None,
backend_options=None,
timeout=None,
pg_tag=None,
device_id=None,
@ -1704,11 +1705,17 @@ def _new_process_group_helper(
return GroupMember.NON_GROUP_MEMBER, None
prefix_store = PrefixStore(f"{group_name}/", store)
base_pg_options = ProcessGroup.Options(backend=str(backend))
base_pg_options._timeout = timeout
# The backend for the PG will be set later based on what's inside BackendConfig,
# and the timeout is set in each backend's options.
pg: ProcessGroup = ProcessGroup(
prefix_store, group_rank, group_size, base_pg_options
prefix_store,
group_rank,
group_size,
)
# Set the default backend when only a single backend is passed in.
if "," not in str(backend) and ":" not in str(backend):
assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
pg._set_default_backend(Backend.backend_type_map[backend])
if device_id:
pg.bound_device_id = device_id
backend_config = BackendConfig(backend)
@ -1735,8 +1742,8 @@ def _new_process_group_helper(
backend_prefix_store,
backend_class.rank(),
backend_class.size(),
base_pg_options,
)
pg._set_default_backend(backend_type)
elif backend_str == Backend.GLOO:
# TODO: remove this check after lazy initialization is supported
# if pg_options is not None:
@ -1748,28 +1755,30 @@ def _new_process_group_helper(
elif backend_str == Backend.NCCL:
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL built in")
if pg_options is not None:
if backend_options is not None:
assert isinstance(
pg_options, ProcessGroupNCCL.Options
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
if pg_options._timeout != timeout:
backend_options, ProcessGroupNCCL.Options
), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
if backend_options._timeout != timeout:
warnings.warn(
"pg_options._timeout was specified, "
"backend_options._timeout was specified, "
"but timeout kwarg has a default value that will always override it. "
)
else:
# default pg_options for NCCL
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = False
pg_options._timeout = timeout
# default backend_options for NCCL
backend_options = ProcessGroupNCCL.Options()
backend_options.is_high_priority_stream = False
backend_options._timeout = timeout
if split_from:
pg_options.split_from = split_from
pg_options.split_color = _process_group_color(global_ranks_in_group)
pg_options.global_ranks_in_group = global_ranks_in_group
pg_options.group_name = group_name
backend_options.split_from = split_from
backend_options.split_color = _process_group_color(
global_ranks_in_group
)
backend_options.global_ranks_in_group = global_ranks_in_group
backend_options.group_name = group_name
backend_class = ProcessGroupNCCL(
backend_prefix_store, group_rank, group_size, pg_options
backend_prefix_store, group_rank, group_size, backend_options
)
backend_type = ProcessGroup.BackendType.NCCL
elif backend_str == Backend.UCC and is_ucc_available():
@ -1804,7 +1813,7 @@ def _new_process_group_helper(
dist_backend_opts.group_id = group_name
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
backend_class = creator_fn(dist_backend_opts, pg_options)
backend_class = creator_fn(dist_backend_opts, backend_options)
# Set sequence numbers for gloo and nccl backends.
if backend_str == Backend.GLOO:
@ -4439,12 +4448,15 @@ def split_group(
global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]
prefix_store = PrefixStore(f"{group_name}/", default_store)
base_pg_options = ProcessGroup.Options(backend=str(backend))
base_pg_options._timeout = timeout
# The backend is registered after initialization, and the timeout is set in pg_options.
pg: ProcessGroup = ProcessGroup(
prefix_store, group_rank, len(my_group), base_pg_options
prefix_store,
group_rank,
len(my_group),
)
backend_type = ProcessGroup.BackendType.NCCL
pg.bound_device_id = device_id
pg._set_default_backend(backend_type)
pg_options._timeout = timeout
pg_options.split_from = parent_backend
@ -4454,7 +4466,6 @@ def split_group(
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, len(my_group), pg_options
)
backend_type = ProcessGroup.BackendType.NCCL
backend_class._set_sequence_number_for_group()
pg._register_backend(torch.device("cuda"), backend_type, backend_class)
@ -4577,7 +4588,7 @@ def _new_group_with_tag(
ranks=None,
timeout=None,
backend=None,
pg_options=None,
backend_options=None,
pg_tag=None,
use_local_synchronization=False,
group_desc=None,
@ -4652,7 +4663,7 @@ def _new_group_with_tag(
backend,
default_store,
group_name,
pg_options=pg_options,
backend_options=backend_options,
timeout=timeout,
pg_tag=pg_tag,
device_id=device_id,