[C10D] add _summarize_ranks util (#160284)

Prints ranges of ranks succinctly.

e.g.

For a strided list of ranks, summarizes down to start:stop:step
```
0:4096:512
```

Omits step if it's 1
```
0:8
```

Note: endpoints are exclusive. This may not be intuitive to everyone,
but in the first above the last rank is 3584, and in the second it is
7.

Currently, does not support combinations of striding _and_ range.  (e.g.
can not generate a representation like "0:2, 4:6, ..., 12:14".  Is this
needed / useful? If so it could be added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160284
Approved by: https://github.com/XilunWu
This commit is contained in:
Will Constable
2025-08-27 10:33:07 -07:00
committed by PyTorch MergeBot
parent 97a548b640
commit fd60117051
2 changed files with 82 additions and 18 deletions

View File

@ -7,16 +7,20 @@ import torch.distributed as c10d
from torch.distributed.collective_utils import (
_check_rng_sync,
_check_rng_sync_internal,
_summarize_ranks,
all_gather,
broadcast,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.testing import FileCheck
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestCollectiveUtils(MultiProcessTestCase):
@ -163,7 +167,47 @@ class TestCollectiveUtils(MultiProcessTestCase):
log_str = _check_rng_sync(generator, group)
FileCheck().check("Generator desync detected").check("Ranks").check("0").check(
"1"
).check("2-3").run(log_str)
).check("2:4").run(log_str)
class TestUtils(TestCase):
def setUp(self):
super().setUp()
if not c10d.is_initialized():
self.rank = 0
self.world_size = 4096
store = FakeStore()
c10d.init_process_group(
backend="fake",
world_size=self.world_size,
rank=self.rank,
store=store,
)
def tearDown(self):
c10d.destroy_process_group()
def test_summarize_ranks(self):
mesh_dim_names = ("pp", "dp", "tp")
mesh = init_device_mesh("cpu", (8, 64, 8), mesh_dim_names=mesh_dim_names)
ranks_lists = {name: mesh[name].mesh.tolist() for name in mesh_dim_names}
summaries = {
name: _summarize_ranks(ranks_lists[name]) for name in mesh_dim_names
}
self.assertEqual(summaries["pp"], "0:4096:512")
self.assertEqual(summaries["dp"], "0:512:8")
self.assertEqual(summaries["tp"], "0:8")
self.assertEqual(
_summarize_ranks([1, 2, 3, 6, 7, 8, 10, 12, 14, 16]),
"1:4,6:9,10:18:2",
)
self.assertEqual(
_summarize_ranks([1]),
"1",
)
instantiate_parametrized_tests(TestCollectiveUtils)

View File

@ -234,24 +234,44 @@ def all_gather_object_enforce_type(
)
def _summarize_ranks(numbers: Iterable[int]) -> str:
numbers = sorted(numbers)
result = []
current_range_start = numbers[0]
for i in range(1, len(numbers)):
if numbers[i] == numbers[i - 1] + 1:
pass
else:
if current_range_start == numbers[i - 1]:
result.append(str(current_range_start))
def _summarize_ranks(ranks: Iterable[int]) -> str:
ranks = sorted(ranks)
assert min(ranks) >= 0, "ranks should all be positive"
assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates"
curr: Optional[Union[int, range]] = None
ranges = []
while ranks:
x = ranks.pop(0)
if curr is None:
curr = x
elif isinstance(curr, int):
if x == curr + 1:
curr = range(curr, x + 1, 1)
else:
result.append(f"{current_range_start}-{numbers[i - 1]}")
current_range_start = numbers[i]
if current_range_start == numbers[-1]:
result.append(str(current_range_start))
else:
result.append(f"{current_range_start}-{numbers[-1]}")
return ", ".join(result)
step = x - curr
curr = range(curr, x + step, step)
else:
assert isinstance(curr, range)
if x == curr.stop:
curr = range(curr.start, curr.stop + curr.step, curr.step)
else:
ranges.append(curr)
curr = x
if isinstance(curr, int):
ranges.append(range(curr, curr + 1, 1))
elif isinstance(curr, range):
ranges.append(curr)
result = []
for r in ranges:
if len(r) == 1:
result.append(f"{r.start}")
elif r.step == 1:
result.append(f"{r.start}:{r.stop}")
else:
result.append(f"{r.start}:{r.stop}:{r.step}")
return ",".join(result)
def _check_philox_rng_sync(