[c10d] allow sub group to be eagerly inited even if default one is not (#138665)
Summary: Currently, eager mode is applied either to all PGs or to none of them. There are cases where we don't want to initialize the comms for the default PG but still want to initialize the comms for a sub PG. With a device_id passed to new_group, this case is now supported.

Test Plan: newly added UT

Resolves https://github.com/pytorch/pytorch/issues/137018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138665
Approved by: https://github.com/kwen2501
ghstack dependencies: #138781
Committed by: PyTorch MergeBot
Parent: 277b32c930
Commit: 4c91481656
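
A minimal usage sketch of the pattern this change enables. The torchrun launch, env-var rank lookup, and the [0, 1] group ranks are illustrative assumptions, not part of the diff; the API calls mirror the new test added below:

import os

import torch
import torch.distributed as dist

# Assumes a 2-rank job launched with torchrun, one CUDA device per rank.
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")

# No device_id here, so the default PG's NCCL communicator stays
# uninitialized until its first collective (lazy init).
dist.init_process_group("nccl")

# Passing device_id to new_group eagerly initializes the subgroup's
# communicator even though the default PG's comm does not exist yet.
subgroup = dist.new_group([0, 1], device_id=device)

tensor = torch.full((1,), rank, device=device)
dist.broadcast(tensor, 0, group=subgroup)  # runs on the eagerly created comm

dist.destroy_process_group()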
@@ -852,6 +852,30 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        self.assertEqual(tensor, original_tensor)
        dist.destroy_process_group()

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_eager_init_subgroup(self):
        # Test `ncclCommSplit` for smaller subgroups of the world when
        # we've passed a specific device_id to init_process_group.
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        # default PG comm is not initialized yet
        pg = self._create_process_group_nccl(store, self.opts())
        backend = pg._get_backend(torch.device(device))
        self.assertEqual(backend._is_initialized(), False)

        tensor = torch.full((1,), self.rank).cuda(device)
        new_group = c10d.new_group([0, 1], device_id=device)
        self.assertEqual(backend.comm_split_count(), 0)

        new_backend = new_group._get_backend(torch.device(device))
        self.assertEqual(new_backend._is_initialized(), True)
        dist.broadcast(tensor, 0, group=new_group)
        self.assertEqual(new_backend.comm_split_count(), 0)
        self.assertEqual(backend._is_initialized(), False)
        torch.cuda.synchronize()
        dist.destroy_process_group()

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_split_group(self):
@@ -578,6 +578,7 @@ class ProcessGroupNCCL(Backend):
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    def abort(self) -> None: ...
    def _is_initialized(self) -> bool: ...
    @property
    def uid(self) -> int: ...
    @property
@@ -445,6 +445,11 @@ class NCCLComm {
#endif
  }

  bool isInitialized() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return initialized_;
  }

  bool isAborted() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return aborted_;
@@ -1076,6 +1076,21 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#endif
}

bool ProcessGroupNCCL::isInitialized() {
  if (devNCCLCommMap_.empty()) {
    return false;
  }
  std::lock_guard<std::mutex> lock(mutex_);
  bool initialized = true;
  for (const auto& [_, comm] : devNCCLCommMap_) {
    if (!comm->isInitialized()) {
      initialized = false;
      break;
    }
  }
  return initialized;
}

c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
    initIntraNodeComm() {
  using IntraNodeComm = intra_node_comm::IntraNodeComm;
@@ -716,6 +716,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  void performNocolorSplit(at::Device device);

  // If all comms on this PG are fully initialized, return true.
  bool isInitialized();

  // This method adds a temporary extension for the timeout period,
  // applying to all collectives between the calling of this API and
  // the completion of the first collective on the GPU. While this feature
@@ -2772,6 +2772,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          .def(
              "abort",
              &::c10d::ProcessGroupNCCL::abort,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_is_initialized",
              &::c10d::ProcessGroupNCCL::isInitialized,
              py::call_guard<py::gil_scoped_release>());

  module.def(
@@ -4722,6 +4722,7 @@ def new_group(
    pg_options=None,
    use_local_synchronization=False,
    group_desc=None,
    device_id: Optional[torch.device] = None,
):
    """
    Create a new distributed group.
@@ -4774,6 +4775,9 @@
            in that non-member ranks don't need to call into API and don't
            join the barrier.
        group_desc (str, optional): a string to describe the process group.
        device_id (torch.device, optional): a single, specific device
            to "bind" this process to, The `new_group` call will try to initialize
            a communication backend immediately for the device if this field is given.

    Returns:
        A handle of distributed group that can be given to collective calls or
@@ -4797,6 +4801,7 @@
        None,
        use_local_synchronization=use_local_synchronization,
        group_desc=group_desc,
        device_id=device_id,
    )

@@ -4808,6 +4813,7 @@ def _new_group_with_tag(
    pg_tag=None,
    use_local_synchronization=False,
    group_desc=None,
    device_id: Optional[torch.device] = None,
):
    """
    Variant of ``new_group`` that exposes tag creation.
@@ -4818,7 +4824,12 @@ def _new_group_with_tag(
     global _world

     default_pg = _get_default_group()
-    device_id = default_pg.bound_device_id
+    if device_id is None:
+        device_id = default_pg.bound_device_id
+    elif default_pg.bound_device_id is not None:
+        assert (
+            device_id == default_pg.bound_device_id
+        ), "Mismatched bound device between new pg and the default pg."
     default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
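
Restated outside the diff, the device-binding rule in this last hunk amounts to the following standalone sketch. The helper name and its signature are hypothetical illustrations, not part of the PR:

from typing import Optional

import torch


def resolve_bound_device(
    requested: Optional[torch.device],
    default_pg_bound: Optional[torch.device],
) -> Optional[torch.device]:
    # Inherit the default PG's bound device when the caller gave none;
    # otherwise the requested device must match the default PG's binding.
    if requested is None:
        return default_pg_bound
    if default_pg_bound is not None:
        assert (
            requested == default_pg_bound
        ), "Mismatched bound device between new pg and the default pg."
    return requested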