Compare commits


1 Commit

588f15cccb  set pg name based on ranks  2025-10-27 13:59:45 -07:00
Summary:
- in torchft we have multiple default pg's, one for each task group
- for flight recorder to work, each of these needs to have a different name so that entries can be matched
- change the `init_process_group` API to optionally take a list of ranks; if provided, we use the hash of the ranks as the name of the pg. For torchft, we'll pass global ranks here so the default pg has a different name on each task group
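
A minimal usage sketch of the new argument. Only the `_ranks` keyword comes from this commit; the backend, init method, and rank values below are illustrative placeholders, not taken from the change.

    import torch.distributed as dist

    # Illustrative setup: each task group initializes its own default process
    # group. Passing that task group's global ranks via `_ranks` makes the
    # derived group name differ per task group, so flight recorder entries
    # can be matched unambiguously.
    global_ranks_for_this_task_group = [0, 1, 2, 3]  # hypothetical values

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=4,
        rank=0,
        _ranks=global_ranks_for_this_task_group,  # new optional argument added in this commit
    )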


@@ -1574,6 +1574,7 @@ def init_process_group(
     group_name: str = "",
     pg_options: Optional[Any] = None,
     device_id: Optional[Union[torch.device, int]] = None,
+    _ranks: Optional[list[int]] = None,
 ) -> None:
     """
     Initialize the default distributed process group.
@@ -1648,6 +1649,8 @@ def init_process_group(
             want to know NCCL initialization error early, you can also use this
             field. If an `int` is provided, the API assumes that the accelerator
             type at compile time will be used.
+        _ranks: The ranks in the process group. If provided, the process
+            group name will be the hash of all the ranks in the group.
 
     .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
         on a system that supports MPI.
@@ -1752,7 +1755,10 @@ def init_process_group(
     internals of c10d. This means we can ignore the value
     they provide as it not exposed in a public way.
     """
-    group_name = _process_group_name([], use_hashed_name=False)
+    if _ranks is None or len(_ranks) == 0:
+        group_name = _process_group_name([], use_hashed_name=False)
+    else:
+        group_name = _process_group_name(_ranks, use_hashed_name=True)
     if backend == Backend.MPI:
         if world_size != -1 or rank != -1:
             warnings.warn(
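
The diff does not show `_process_group_name` itself, so the snippet below is only a sketch of the idea behind `use_hashed_name=True`: derive the group name from a stable hash of the rank list, so that task groups with different global ranks get different default-pg names. The helper name and hashing details here are assumptions for illustration, not the real implementation.

    import hashlib

    def illustrative_hashed_pg_name(ranks: list[int]) -> str:
        # Assumption: the name is a stable hash over the sorted ranks.
        # The real _process_group_name(..., use_hashed_name=True) may differ.
        return hashlib.sha1("_".join(str(r) for r in sorted(ranks)).encode()).hexdigest()

    # Two task groups with disjoint global ranks end up with distinct
    # default-pg names, which is what lets flight recorder match entries
    # per task group.
    print(illustrative_hashed_pg_name([0, 1, 2, 3]))
    print(illustrative_hashed_pg_name([4, 5, 6, 7]))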