Mirror of https://github.com/pytorch/pytorch.git
[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)
We introduced the dispatchable backend for ProcessGroup and collectives in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup that removes the Options class from ProcessGroup: users now either set the timeout/backend later or create a backend directly after creating a process group. ProcessGroupNCCL was also using the Options class from ProcessGroup, but it should use the Options class from the Backend base class, so this PR aligns the types and names with the C++ side. The public API signatures are unchanged, so they still accept arguments named "pg_options". The tests are updated to match the change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132931
Approved by: https://github.com/H-Huang
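To illustrate the resulting flow, here is a minimal sketch of how a default process group is wired up after this change. It mirrors the private helper logic shown in the diff below rather than a documented public API, and it assumes a build with NCCL available plus hypothetical single-process rendezvous values:

import os

import torch.distributed as dist


def init_default_nccl_group() -> None:
    # Hypothetical rendezvous settings for a one-process run (illustration only).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # After this PR, ProcessGroupNCCL.Options derives from Backend.Options
    # instead of the removed ProcessGroup.Options; the public keyword is
    # still named pg_options.
    opts = dist.ProcessGroupNCCL.Options()
    opts.is_high_priority_stream = False

    dist.init_process_group(
        backend="nccl",
        rank=0,
        world_size=1,
        pg_options=opts,
    )

    pg = dist.distributed_c10d._get_default_group()
    # The group now reports its backend from the default BackendType that the
    # helper sets via pg._set_default_backend(...), not from a
    # ProcessGroup.Options.backend string.
    print(pg.name())  # "nccl"

    dist.destroy_process_group()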
@@ -1815,6 +1815,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
     def test_init_process_group_for_all_backends(self):
         for backend in dist.Backend.backend_list:
+            excepted_backend = backend
             # skip if the backend is not available on the system
             if backend == dist.Backend.UNDEFINED:
                 continue
@@ -1830,6 +1831,11 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
             elif backend == dist.Backend.UCC:
                 if not dist.is_ucc_available():
                     continue
+            # Multi-threaded PG is defined as a pure python class.
+            # Its pg.name() does not going through Pybind, so its backend name
+            # is still "threaded" instead of "custom".
+            elif backend != "threaded":
+                excepted_backend = "custom"

             with tempfile.NamedTemporaryFile(delete=False) as f:
                 store = dist.FileStore(f.name, self.world_size)
@@ -1842,7 +1848,7 @@ class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
                 pg = c10d._get_default_group()
                 self.assertEqual(pg.rank(), self.rank)
                 self.assertEqual(pg.size(), self.world_size)
-                self.assertEqual(pg.name(), str(backend))
+                self.assertEqual(pg.name(), str(excepted_backend))

                 dist.destroy_process_group()

@@ -232,7 +232,8 @@ class DeviceMeshTest(DTensorTestBase):

         mesh_tensor = torch.arange(4).reshape(2, 2)
         mesh = DeviceMesh(device_type, mesh_tensor)
-        self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")
+        # Fake pg only have BackendType as BackendType::CUSTOM.
+        self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")


 class DeviceMeshTestNDim(DTensorTestBase):

@@ -296,15 +296,6 @@ class Backend:
     def _set_default_timeout(self, timeout: timedelta) -> None: ...

 class ProcessGroup:
-    class Options:
-        def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
-        @property
-        def backend(self) -> str: ...
-        @property
-        def _timeout(self) -> timedelta: ...
-        @_timeout.setter
-        def _timeout(self, val: timedelta) -> None: ...
-
     class BackendType(Enum):
         UNDEFINED = ...
         GLOO = ...
@@ -318,7 +309,6 @@ class ProcessGroup:
         store: Store,
         rank: int,
         size: int,
-        options: Options,
     ) -> None: ...
     def rank(self) -> int: ...
     def size(self) -> int: ...
@@ -508,6 +498,7 @@ class ProcessGroup:
     @property
     def _device_types(self) -> list[torch.device]: ...
     def _get_backend(self, device: torch.device) -> Backend: ...
+    def _set_default_backend(self, backend_type: BackendType) -> None: ...
     def _register_backend(
         self,
         device: torch.device,
@@ -532,7 +523,7 @@ class ProcessGroup:
 class ProcessGroupGloo(Backend):
     class Device: ...

-    class Options(ProcessGroup.Options):
+    class Options(Backend.Options):
         devices: list[ProcessGroupGloo.Device]
         threads: int

@@ -562,7 +553,7 @@ class ProcessGroupNCCL(Backend):
     min_ctas: int
     max_ctas: int

-    class Options(ProcessGroup.Options):
+    class Options(Backend.Options):
         config: ProcessGroupNCCL.NCCLConfig
         is_high_priority_stream: bool
         split_from: ProcessGroupNCCL

@@ -14,22 +14,6 @@

 namespace c10d {

-static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
-  if (backend == "undefined") {
-    return ProcessGroup::BackendType::UNDEFINED;
-  } else if (backend == "gloo") {
-    return ProcessGroup::BackendType::GLOO;
-  } else if (backend == "nccl") {
-    return ProcessGroup::BackendType::NCCL;
-  } else if (backend == "ucc") {
-    return ProcessGroup::BackendType::UCC;
-  } else if (backend == "mpi") {
-    return ProcessGroup::BackendType::MPI;
-  } else {
-    return ProcessGroup::BackendType::CUSTOM;
-  }
-}
-
 std::string opTypeToString(OpType opType) {
   switch (opType) {
     case OpType::BROADCAST:
@@ -119,13 +103,11 @@ c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
 ProcessGroup::ProcessGroup(
     const c10::intrusive_ptr<::c10d::Store>& store,
     int rank,
-    int size,
-    c10::intrusive_ptr<Options> options)
+    int size)
     : store_(store),
       rank_(rank),
       size_(size),
-      options_(std::move(options)),
-      backendType_(strToBackendType(options_->backend)),
+      backendType_(BackendType::UNDEFINED),
       dist_debug_level_(debug_level()) {
   C10_LOG_API_USAGE_ONCE("c10d.process_group");
 }

@@ -45,24 +45,6 @@ namespace c10d {
 //
 class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  public:
-  // ProcessGroup Options is a base struct that defines the basic options
-  // when constructing a ProcessGroup. Each ProcessGroup subclass should
-  // extend this struct and define its options if it wants to provide more
-  // config options (beyond basic ones defined here) to end user.
-  struct TORCH_API Options : torch::CustomClassHolder {
-    explicit Options(
-        std::string backend,
-        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
-        : timeout(timeout), backend(std::move(backend)) {}
-    ~Options() override = default;
-
-    std::chrono::milliseconds timeout;
-
-    // backend name
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-    const std::string backend;
-  };
-
   enum BackendType : uint8_t {
     UNDEFINED = 0,
     GLOO = 1,
@@ -72,6 +54,23 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     CUSTOM = 5,
   };

+  static std::string backendTypeToString(BackendType type) {
+    switch (type) {
+      case BackendType::GLOO:
+        return "gloo";
+      case BackendType::NCCL:
+        return "nccl";
+      case BackendType::UCC:
+        return "ucc";
+      case BackendType::MPI:
+        return "mpi";
+      case BackendType::UNDEFINED:
+        return "undefined";
+      default:
+        return "custom";
+    }
+  };
+
   // Not used, set for backwards compatibility and only used for TypeDef in
   // Ops.cpp
   explicit ProcessGroup(int rank, int size);
@@ -79,8 +78,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   explicit ProcessGroup(
       const c10::intrusive_ptr<::c10d::Store>& store,
       int rank,
-      int size,
-      c10::intrusive_ptr<Options> options);
+      int size);
   ~ProcessGroup() override;

   int getRank() const {
@@ -103,7 +101,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   }

   virtual const std::string getBackendName() const {
-    return options_->backend;
+    return backendTypeToString(backendType_);
   };

   BackendType getBackendType() const {
@@ -609,10 +607,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         opts.timeout.count());
   }

-  c10::intrusive_ptr<Options> getOptions() {
-    return options_;
-  }
-
   bool hasBackends() {
     return !deviceTypeToBackendType_.empty();
   }
@@ -653,6 +647,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     return backendTypeToBackend_.at(backendType_);
   }

+  void setDefaultBackend(const BackendType& backendType) {
+    backendType_ = backendType;
+  }
+
   c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

   c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
@@ -725,9 +723,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const int size_;
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-  const c10::intrusive_ptr<Options> options_;
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-  const BackendType backendType_;
+  BackendType backendType_;
   std::string pg_desc_;

   // Debug level setting. It is parsed once when ProcessGroup is constructed and

@@ -1814,8 +1814,7 @@ communication mechanism.
       py::init<
           const c10::intrusive_ptr<::c10d::Store>&,
           int,
-          int,
-          c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
+          int>(),
           py::call_guard<py::gil_scoped_release>())
       .def("rank", &::c10d::ProcessGroup::getRank)
       .def("size", &::c10d::ProcessGroup::getSize)
@@ -1825,7 +1824,6 @@ communication mechanism.
           "_backend_id",
           &::c10d::ProcessGroup::getBackendID,
           py::arg("backend_type"))
-      .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
       .def(
           "broadcast",
           &::c10d::ProcessGroup::broadcast,
@@ -2135,6 +2133,14 @@ communication mechanism.
           },
           py::arg("device"),
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "_set_default_backend",
+          [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
+             const ::c10d::ProcessGroup::BackendType& backendType) {
+            return self->setDefaultBackend(backendType);
+          },
+          py::arg("backend_type"),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "_register_on_completion_hook",
           [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
@@ -2237,27 +2243,6 @@ Arguments:
       .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
       .export_values();

-  // base ProcessGroup::Options binding
-  auto processGroupOptions =
-      intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
-          processGroup,
-          "Options",
-          R"(
-Base class for all processes group options implementations, such as the nccl
-options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
-)")
-          .def(
-              py::init([](const std::string& backend,
-                          const std::chrono::milliseconds& timeout) {
-                return c10::make_intrusive<::c10d::ProcessGroup::Options>(
-                    backend, timeout);
-              }),
-              py::arg("backend"),
-              py::arg("timeout") = kProcessGroupDefaultTimeout,
-              py::call_guard<py::gil_scoped_release>())
-          .def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
-          .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
-
   // TODO: The collection definitions handles direct instantiation of
   // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported
   // and should be removed once all tests are transitioned
@@ -2556,6 +2541,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           &::c10d::Backend::endCoalescing,
           py::call_guard<py::gil_scoped_release>());

+  // base Backend::Options binding
+  // TODO: Maybe we can consider how to merge this with
+  // `DistributedBackendOptions`.
+  auto backendOptions =
+      intrusive_ptr_class_<::c10d::Backend::Options>(
+          backend,
+          "Options",
+          R"(
+Base class for all backend options implementations, such as the nccl
+options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
+)")
+          .def(
+              py::init([](const std::string& backend,
+                          const std::chrono::milliseconds& timeout) {
+                return c10::make_intrusive<::c10d::Backend::Options>(
+                    backend, timeout);
+              }),
+              py::arg("backend"),
+              py::arg("timeout") = kProcessGroupDefaultTimeout,
+              py::call_guard<py::gil_scoped_release>())
+          .def_readonly("backend", &::c10d::Backend::Options::backend)
+          .def_readwrite("_timeout", &::c10d::Backend::Options::timeout);
+
 #ifdef USE_C10D_GLOO
   static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";

@@ -2567,7 +2575,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
   shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");

   intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
-      processGroupGloo, "_Options", processGroupOptions)
+      processGroupGloo, "_Options", backendOptions)
       .def(py::init<>())
       .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
       .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
@@ -2794,7 +2802,7 @@ for details.
   intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
       processGroupNCCL,
       "Options",
-      processGroupOptions,
+      backendOptions,
       R"(
 ProcessGroup options for the NCCL backend

@@ -36,6 +36,7 @@ if not is_available():


 else:
+    from torch._C._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
         _find_pg_by_ranks_and_tag,
         _get_default_group,
@@ -66,7 +67,7 @@ else:
         self.mesh_stack: List[DeviceMesh] = []
         self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
         self.mesh_dim_group_options: Dict[
-            int, Tuple[str, Optional[ProcessGroup.Options]]
+            int, Tuple[str, Optional[C10dBackend.Options]]
         ] = {}
         self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
         # Record flatten mesh name to its mesh dim index in root mesh.
@@ -279,7 +280,7 @@ else:
         self,
         dim: int,
         backend: str,
-        pg_options: Optional[ProcessGroup.Options] = None,
+        pg_options: Optional[C10dBackend.Options] = None,
     ) -> None:
         self.mesh_dim_group_options[dim] = (backend, pg_options)

@@ -266,6 +266,7 @@ class Backend(str):
         GLOO: ProcessGroup.BackendType.GLOO,
         NCCL: ProcessGroup.BackendType.NCCL,
         UCC: ProcessGroup.BackendType.UCC,
+        MPI: ProcessGroup.BackendType.MPI,
     }

     def __new__(cls, name: str):
@@ -1531,7 +1532,7 @@ def init_process_group(
         backend,
         store,
         group_name,
-        pg_options=pg_options,
+        backend_options=pg_options,
         timeout=timeout,
         device_id=device_id,
         group_desc="default_pg",
@@ -1628,7 +1629,7 @@ def _new_process_group_helper(
     backend,
     store,
     group_name,
-    pg_options=None,
+    backend_options=None,
     timeout=None,
     pg_tag=None,
     device_id=None,
@@ -1704,11 +1705,17 @@ def _new_process_group_helper(
         return GroupMember.NON_GROUP_MEMBER, None

     prefix_store = PrefixStore(f"{group_name}/", store)
-    base_pg_options = ProcessGroup.Options(backend=str(backend))
-    base_pg_options._timeout = timeout
+    # The backend for PG will be set later based on what's inside BackendConfig
+    # and timeout are set in each backend's option.
     pg: ProcessGroup = ProcessGroup(
-        prefix_store, group_rank, group_size, base_pg_options
+        prefix_store,
+        group_rank,
+        group_size,
     )
+    # Set the default backend when only single backend is passed in.
+    if "," not in str(backend) and ":" not in str(backend):
+        assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
+        pg._set_default_backend(Backend.backend_type_map[backend])
     if device_id:
         pg.bound_device_id = device_id
     backend_config = BackendConfig(backend)
@@ -1735,8 +1742,8 @@ def _new_process_group_helper(
             backend_prefix_store,
             backend_class.rank(),
             backend_class.size(),
-            base_pg_options,
         )
+        pg._set_default_backend(backend_type)
     elif backend_str == Backend.GLOO:
         # TODO: remove this check after lazy initialization is supported
         # if pg_options is not None:
@@ -1748,28 +1755,30 @@ def _new_process_group_helper(
     elif backend_str == Backend.NCCL:
         if not is_nccl_available():
             raise RuntimeError("Distributed package doesn't have NCCL built in")
-        if pg_options is not None:
+        if backend_options is not None:
             assert isinstance(
-                pg_options, ProcessGroupNCCL.Options
-            ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
-            if pg_options._timeout != timeout:
+                backend_options, ProcessGroupNCCL.Options
+            ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
+            if backend_options._timeout != timeout:
                 warnings.warn(
-                    "pg_options._timeout was specified, "
+                    "backend_options._timeout was specified, "
                     "but timeout kwarg has a default value that will always override it. "
                 )
         else:
-            # default pg_options for NCCL
-            pg_options = ProcessGroupNCCL.Options()
-            pg_options.is_high_priority_stream = False
-            pg_options._timeout = timeout
+            # default backend_options for NCCL
+            backend_options = ProcessGroupNCCL.Options()
+            backend_options.is_high_priority_stream = False
+            backend_options._timeout = timeout

         if split_from:
-            pg_options.split_from = split_from
-            pg_options.split_color = _process_group_color(global_ranks_in_group)
-            pg_options.global_ranks_in_group = global_ranks_in_group
-            pg_options.group_name = group_name
+            backend_options.split_from = split_from
+            backend_options.split_color = _process_group_color(
+                global_ranks_in_group
+            )
+            backend_options.global_ranks_in_group = global_ranks_in_group
+            backend_options.group_name = group_name
         backend_class = ProcessGroupNCCL(
-            backend_prefix_store, group_rank, group_size, pg_options
+            backend_prefix_store, group_rank, group_size, backend_options
         )
         backend_type = ProcessGroup.BackendType.NCCL
     elif backend_str == Backend.UCC and is_ucc_available():
@@ -1804,7 +1813,7 @@ def _new_process_group_helper(
             dist_backend_opts.group_id = group_name
             dist_backend_opts.global_ranks_in_group = global_ranks_in_group

-            backend_class = creator_fn(dist_backend_opts, pg_options)
+            backend_class = creator_fn(dist_backend_opts, backend_options)

         # Set sequence numbers for gloo and nccl backends.
         if backend_str == Backend.GLOO:
@@ -4439,12 +4448,15 @@ def split_group(
     global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group]

     prefix_store = PrefixStore(f"{group_name}/", default_store)
-    base_pg_options = ProcessGroup.Options(backend=str(backend))
-    base_pg_options._timeout = timeout
+    # We register the backend after initializing and timeout is set in pg_options.
     pg: ProcessGroup = ProcessGroup(
-        prefix_store, group_rank, len(my_group), base_pg_options
+        prefix_store,
+        group_rank,
+        len(my_group),
     )
+    backend_type = ProcessGroup.BackendType.NCCL
     pg.bound_device_id = device_id
+    pg._set_default_backend(backend_type)

     pg_options._timeout = timeout
     pg_options.split_from = parent_backend
@@ -4454,7 +4466,6 @@ def split_group(
     backend_class = ProcessGroupNCCL(
         prefix_store, group_rank, len(my_group), pg_options
     )
-    backend_type = ProcessGroup.BackendType.NCCL
     backend_class._set_sequence_number_for_group()

     pg._register_backend(torch.device("cuda"), backend_type, backend_class)
@@ -4577,7 +4588,7 @@ def _new_group_with_tag(
     ranks=None,
     timeout=None,
     backend=None,
-    pg_options=None,
+    backend_options=None,
     pg_tag=None,
     use_local_synchronization=False,
     group_desc=None,
@@ -4652,7 +4663,7 @@ def _new_group_with_tag(
         backend,
         default_store,
         group_name,
-        pg_options=pg_options,
+        backend_options=backend_options,
         timeout=timeout,
         pg_tag=pg_tag,
         device_id=device_id,