[c10d] allow sub group to be eagerly inited even if default one is not (#138665)

Summary:
Currently, eager mode is applied either to all PGs or to none of them.
There are cases where we don't want to initialize the comms for the default
PG, but we still want to initialize the comms for a sub PG. Now, with a
device_id passed to `new_group`, we can achieve this.
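
For illustration, a minimal usage sketch (assuming a 2-rank job launched via
torchrun with one CUDA device per rank; the setup code is illustrative and not
part of this PR):

    import os
    import torch
    import torch.distributed as dist

    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank}")

    # No device_id here, so the default PG's NCCL comm is NOT eagerly initialized.
    dist.init_process_group("nccl")

    # Passing device_id eagerly initializes the comm for this sub PG only.
    subgroup = dist.new_group([0, 1], device_id=device)

    # Collectives on the subgroup work without ever eagerly initializing the default PG.
    t = torch.full((1,), rank, device=device)
    dist.broadcast(t, 0, group=subgroup)

    dist.destroy_process_group()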
Test Plan:
newly added UT

Tags:

Resolves https://github.com/pytorch/pytorch/issues/137018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138665
Approved by: https://github.com/kwen2501
ghstack dependencies: #138781
Shuqiang Zhang
2024-10-24 13:49:28 -07:00
committed by PyTorch MergeBot
parent 277b32c930
commit 4c91481656
7 changed files with 64 additions and 1 deletion

View File

@@ -852,6 +852,30 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        self.assertEqual(tensor, original_tensor)
        dist.destroy_process_group()

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_eager_init_subgroup(self):
        # Test `ncclCommSplit` for smaller subgroups of the world when
        # we've passed a specific device_id to init_process_group.
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        # default PG comm is not initialized yet
        pg = self._create_process_group_nccl(store, self.opts())
        backend = pg._get_backend(torch.device(device))
        self.assertEqual(backend._is_initialized(), False)

        tensor = torch.full((1,), self.rank).cuda(device)
        new_group = c10d.new_group([0, 1], device_id=device)
        self.assertEqual(backend.comm_split_count(), 0)

        new_backend = new_group._get_backend(torch.device(device))
        self.assertEqual(new_backend._is_initialized(), True)
        dist.broadcast(tensor, 0, group=new_group)
        self.assertEqual(new_backend.comm_split_count(), 0)
        self.assertEqual(backend._is_initialized(), False)
        torch.cuda.synchronize()
        dist.destroy_process_group()

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_split_group(self):

View File

@@ -578,6 +578,7 @@ class ProcessGroupNCCL(Backend):
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    def abort(self) -> None: ...
    def _is_initialized(self) -> bool: ...
    @property
    def uid(self) -> int: ...
    @property

View File

@@ -445,6 +445,11 @@ class NCCLComm {
#endif
  }

  bool isInitialized() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return initialized_;
  }

  bool isAborted() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return aborted_;

View File

@@ -1076,6 +1076,21 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#endif
}

bool ProcessGroupNCCL::isInitialized() {
  if (devNCCLCommMap_.empty()) {
    return false;
  }
  std::lock_guard<std::mutex> lock(mutex_);
  bool initialized = true;
  for (const auto& [_, comm] : devNCCLCommMap_) {
    if (!comm->isInitialized()) {
      initialized = false;
      break;
    }
  }
  return initialized;
}

c10::intrusive_ptr<intra_node_comm::IntraNodeComm> ProcessGroupNCCL::
    initIntraNodeComm() {
  using IntraNodeComm = intra_node_comm::IntraNodeComm;

View File

@@ -716,6 +716,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  void performNocolorSplit(at::Device device);

  // If all comms on this PG are fully initialized, return true.
  bool isInitialized();

  // This method adds a temporary extension for the timeout period,
  // applying to all collectives between the calling of this API and
  // the completion of the first collective on the GPU. While this feature

View File

@@ -2772,6 +2772,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
          .def(
              "abort",
              &::c10d::ProcessGroupNCCL::abort,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_is_initialized",
              &::c10d::ProcessGroupNCCL::isInitialized,
              py::call_guard<py::gil_scoped_release>());

  module.def(

View File

@@ -4722,6 +4722,7 @@ def new_group(
    pg_options=None,
    use_local_synchronization=False,
    group_desc=None,
    device_id: Optional[torch.device] = None,
):
    """
    Create a new distributed group.
@@ -4774,6 +4775,9 @@ def new_group(
            in that non-member ranks don't need to call into API and don't
            join the barrier.
        group_desc (str, optional): a string to describe the process group.
        device_id (torch.device, optional): a single, specific device to
            "bind" this process to. The `new_group` call will try to initialize
            a communication backend immediately for the device if this field is
            given.

    Returns:
        A handle of distributed group that can be given to collective calls or
@@ -4797,6 +4801,7 @@ def new_group(
        None,
        use_local_synchronization=use_local_synchronization,
        group_desc=group_desc,
        device_id=device_id,
    )
@@ -4808,6 +4813,7 @@ def _new_group_with_tag(
    pg_tag=None,
    use_local_synchronization=False,
    group_desc=None,
    device_id: Optional[torch.device] = None,
):
    """
    Variant of ``new_group`` that exposes tag creation.
@@ -4818,7 +4824,12 @@ def _new_group_with_tag(
    global _world
    default_pg = _get_default_group()
    device_id = default_pg.bound_device_id
    if device_id is None:
        device_id = default_pg.bound_device_id
    elif default_pg.bound_device_id is not None:
        assert (
            device_id == default_pg.bound_device_id
        ), "Mismatched bound device between new pg and the default pg."
    default_backend, default_store = _world.pg_map[default_pg]
    global_rank = default_pg.rank()
    global_world_size = default_pg.size()