mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[C10D] add _summarize_ranks util (#160284)
Prints ranges of ranks succinctly. e.g. For a strided list of ranks, summarizes down to start:stop:step ``` 0:4096:512 ``` Omits step if it's 1 ``` 0:8 ``` Note: endpoints are exclusive. This may not be intuitive to everyone, but in the first above the last rank is 3584, and in the second it is 7. Currently, does not support combinations of striding _and_ range. (e.g. can not generate a representation like "0:2, 4:6, ..., 12:14". Is this needed / useful? If so it could be added. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160284 Approved by: https://github.com/XilunWu
This commit is contained in:
committed by
PyTorch MergeBot
parent
97a548b640
commit
fd60117051
@ -7,16 +7,20 @@ import torch.distributed as c10d
|
||||
from torch.distributed.collective_utils import (
|
||||
_check_rng_sync,
|
||||
_check_rng_sync_internal,
|
||||
_summarize_ranks,
|
||||
all_gather,
|
||||
broadcast,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
|
||||
class TestCollectiveUtils(MultiProcessTestCase):
|
||||
@ -163,7 +167,47 @@ class TestCollectiveUtils(MultiProcessTestCase):
|
||||
log_str = _check_rng_sync(generator, group)
|
||||
FileCheck().check("Generator desync detected").check("Ranks").check("0").check(
|
||||
"1"
|
||||
).check("2-3").run(log_str)
|
||||
).check("2:4").run(log_str)
|
||||
|
||||
|
||||
class TestUtils(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
if not c10d.is_initialized():
|
||||
self.rank = 0
|
||||
self.world_size = 4096
|
||||
|
||||
store = FakeStore()
|
||||
c10d.init_process_group(
|
||||
backend="fake",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
c10d.destroy_process_group()
|
||||
|
||||
def test_summarize_ranks(self):
|
||||
mesh_dim_names = ("pp", "dp", "tp")
|
||||
mesh = init_device_mesh("cpu", (8, 64, 8), mesh_dim_names=mesh_dim_names)
|
||||
ranks_lists = {name: mesh[name].mesh.tolist() for name in mesh_dim_names}
|
||||
summaries = {
|
||||
name: _summarize_ranks(ranks_lists[name]) for name in mesh_dim_names
|
||||
}
|
||||
self.assertEqual(summaries["pp"], "0:4096:512")
|
||||
self.assertEqual(summaries["dp"], "0:512:8")
|
||||
self.assertEqual(summaries["tp"], "0:8")
|
||||
|
||||
self.assertEqual(
|
||||
_summarize_ranks([1, 2, 3, 6, 7, 8, 10, 12, 14, 16]),
|
||||
"1:4,6:9,10:18:2",
|
||||
)
|
||||
self.assertEqual(
|
||||
_summarize_ranks([1]),
|
||||
"1",
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCollectiveUtils)
|
||||
|
@ -234,24 +234,44 @@ def all_gather_object_enforce_type(
|
||||
)
|
||||
|
||||
|
||||
def _summarize_ranks(numbers: Iterable[int]) -> str:
|
||||
numbers = sorted(numbers)
|
||||
result = []
|
||||
current_range_start = numbers[0]
|
||||
for i in range(1, len(numbers)):
|
||||
if numbers[i] == numbers[i - 1] + 1:
|
||||
pass
|
||||
else:
|
||||
if current_range_start == numbers[i - 1]:
|
||||
result.append(str(current_range_start))
|
||||
def _summarize_ranks(ranks: Iterable[int]) -> str:
|
||||
ranks = sorted(ranks)
|
||||
assert min(ranks) >= 0, "ranks should all be positive"
|
||||
assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates"
|
||||
curr: Optional[Union[int, range]] = None
|
||||
ranges = []
|
||||
while ranks:
|
||||
x = ranks.pop(0)
|
||||
if curr is None:
|
||||
curr = x
|
||||
elif isinstance(curr, int):
|
||||
if x == curr + 1:
|
||||
curr = range(curr, x + 1, 1)
|
||||
else:
|
||||
result.append(f"{current_range_start}-{numbers[i - 1]}")
|
||||
current_range_start = numbers[i]
|
||||
if current_range_start == numbers[-1]:
|
||||
result.append(str(current_range_start))
|
||||
else:
|
||||
result.append(f"{current_range_start}-{numbers[-1]}")
|
||||
return ", ".join(result)
|
||||
step = x - curr
|
||||
curr = range(curr, x + step, step)
|
||||
else:
|
||||
assert isinstance(curr, range)
|
||||
if x == curr.stop:
|
||||
curr = range(curr.start, curr.stop + curr.step, curr.step)
|
||||
else:
|
||||
ranges.append(curr)
|
||||
curr = x
|
||||
|
||||
if isinstance(curr, int):
|
||||
ranges.append(range(curr, curr + 1, 1))
|
||||
elif isinstance(curr, range):
|
||||
ranges.append(curr)
|
||||
|
||||
result = []
|
||||
for r in ranges:
|
||||
if len(r) == 1:
|
||||
result.append(f"{r.start}")
|
||||
elif r.step == 1:
|
||||
result.append(f"{r.start}:{r.stop}")
|
||||
else:
|
||||
result.append(f"{r.start}:{r.stop}:{r.step}")
|
||||
return ",".join(result)
|
||||
|
||||
|
||||
def _check_philox_rng_sync(
|
||||
|
Reference in New Issue
Block a user