[c10d] init_process_group supports index-only device id (#156214)

Before:
```
acc = torch.accelerator.current_accelerator()
if acc:
  local_idx = ...
  dist.init_process_group(
    device_id=torch.device(acc.type, local_idx)
  )
```
After:
```
dist.init_process_group(device_id=local_idx)
```

That is, when only a device index is given, `init_process_group` resolves the device type by calling `torch.accelerator.current_accelerator()` internally.
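
For illustration, a minimal sketch of the new form under a torchrun-style launch, where `LOCAL_RANK` is assumed to be set by the launcher (not part of this commit):

```
import os

import torch.distributed as dist

# Launcher-provided local index; torchrun exports LOCAL_RANK (assumption for this sketch).
local_idx = int(os.environ.get("LOCAL_RANK", "0"))

# The device type (cuda, xpu, ...) is resolved internally from
# torch.accelerator.current_accelerator(), so the index alone is enough.
dist.init_process_group(device_id=local_idx)

# ... collectives / training loop ...

dist.destroy_process_group()
```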

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156214
Approved by: https://github.com/guangyey, https://github.com/albanD
Ke Wen
2025-06-20 16:59:19 -07:00
committed by PyTorch MergeBot
parent fbbab794ef
commit 0f0c010714
3 changed files with 51 additions and 65 deletions

torch/distributed/distributed_c10d.py

```
@@ -1549,7 +1549,7 @@ def init_process_group(
     store: Optional[Store] = None,
     group_name: str = "",
     pg_options: Optional[Any] = None,
-    device_id: Optional[torch.device] = None,
+    device_id: Optional[Union[torch.device, int]] = None,
 ) -> None:
     """
     Initialize the default distributed process group.
```
```
@@ -1612,15 +1612,16 @@ def init_process_group(
             the nccl backend can pick up high priority cuda streams when
             there're compute kernels waiting. For other availble options to config nccl,
             See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
-        device_id (torch.device, optional): a single, specific device
-            to "bind" this process to, allowing for backend-specific
+        device_id (torch.device | int, optional): a single, specific device
+            this process will work on, allowing for backend-specific
             optimizations. Currently this has two effects, only under
             NCCL: the communicator is immediately formed (calling
             ``ncclCommInit*`` immediately rather than the normal lazy
             call) and sub-groups will use ``ncclCommSplit`` when
             possible to avoid unnecessary overhead of group creation. If you
             want to know NCCL initialization error early, you can also use this
-            field.
+            field. If an `int` is provided, the API assumes that the accelerator
+            type at compile time will be used.

     .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
         on a system that supports MPI.
```
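
A rough sketch of what the docstring above describes, assuming a CUDA build with NCCL, one GPU per rank, and a launcher that sets `LOCAL_RANK`; the subgroup ranks are arbitrary:

```
import os

import torch.distributed as dist

local_idx = int(os.environ.get("LOCAL_RANK", "0"))

# Binding a device at init time triggers eager communicator creation
# (ncclCommInit* now rather than lazily at the first collective), so NCCL
# initialization errors surface here.
dist.init_process_group(backend="nccl", device_id=local_idx)

# Sub-groups can then be split off the existing communicator with
# ncclCommSplit when possible, avoiding a second full initialization.
subgroup = dist.new_group(ranks=[0, 1])
```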
```
@@ -1665,6 +1666,39 @@ def init_process_group(
     elif init_method is None:
         init_method = "env://"

+    # Get the compile-time accelerator type.
+    # None indicates no accelerator support.
+    acc = torch.accelerator.current_accelerator()
+
+    # Auto complete device id
+    if isinstance(device_id, int):
+        if acc is None:
+            raise ValueError(
+                "device_id is an int, but no accelerator support is found from the current compilation. "
+                "Please use a different compiled version that supports your accelerator."
+            )
+        device_id = torch.device(acc.type, device_id)
+
+    # Sanity check device_id
+    if device_id is not None and device_id.type != "cpu":
+        # Type
+        if acc is None or device_id.type != acc.type:
+            raise ValueError(
+                f"device_id {device_id} does not match the current compilation's accelerator support: {acc}. "
+                "Please use a different compiled version that supports your accelerator."
+            )
+        # Index
+        if device_id.index is None:
+            raise ValueError("Please use a device_id with index.")
+        # Range
+        if device_id.index >= torch.accelerator.device_count():
+            raise ValueError(
+                f"device_id {device_id} is out of range. Please use a device index less than "
+                f"the number of accelerators available: {torch.accelerator.device_count()}."
+            )
+        logger.info("Using device: %s", device_id)
+
     # If user did not provide a backend string but provided a device id, e.g.
     # >>> init_process_group(device_id=device)
     # we try to figure out the backend name based on the device type.
```
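
A sketch of how the new checks surface to callers; since they run before rendezvous (as in the hunk above), no `MASTER_ADDR`/`MASTER_PORT` is needed to hit them:

```
import torch
import torch.distributed as dist

if torch.accelerator.current_accelerator() is None:
    # Build without accelerator support: an int device_id is rejected outright.
    try:
        dist.init_process_group(device_id=0)
    except ValueError as err:
        print(err)
else:
    # Accelerator build: the int is completed to torch.device(acc.type, index),
    # and an out-of-range index is caught early.
    try:
        dist.init_process_group(device_id=torch.accelerator.device_count())
    except ValueError as err:
        print(err)
```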