pytorch/test/distributed/test_control_collectives.py
Tristan Rice 4b2ae2ac33 c10d: add Collectives abstraction (#125978)
This adds a new `Collectives` API for performing distributed collective operations. It is intended to replace the [current Elastic store abstraction](https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/utils/store.py) with more performant and debuggable primitives.

Design doc: https://docs.google.com/document/d/147KcKJXEHvk1Q6tISLbJVvLejHg_1kIhBQeu-8RQxhY/edit

The standard implementation is `StoreCollectives`; other, more performant backends will be added in a follow-up PR.
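
As a quick orientation, here is a minimal usage sketch adapted from the tests below. The `_StoreCollectives` binding and the `barrier`/`all_sum` signatures are taken directly from the test file; the key names and the thread driver are just for illustration.

```
from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch.distributed as dist

store = dist.HashStore()  # in-process store, as used by the tests below
world_size = 4

def worker(rank: int) -> None:
    # Each rank gets its own collectives handle bound to the shared store.
    collectives = dist._StoreCollectives(store, rank, world_size)
    collectives.barrier("start", timedelta(seconds=10), True)
    total = collectives.all_sum("ranks", rank, timedelta(seconds=10))
    assert total == sum(range(world_size))

# Threads stand in for separate processes, mirroring the test setup.
with ThreadPool(world_size) as pool:
    pool.map(worker, range(world_size))
```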

Test plan:

```
python test/distributed/test_control_collectives.py -v
```

This tests both basic functionality, driven by multiple threads, and timeout behavior.
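
For example, the timeout path can be exercised with a single rank and no peers, a sketch adapted from `test_broadcast_timeout` below:

```
from datetime import timedelta

import torch.distributed as dist

store = dist.HashStore()
collectives = dist._StoreCollectives(store, 1, 4)  # rank 1 of a nominal world of 4
# No other rank ever sends, so this raises an exception matching "Wait timeout".
collectives.broadcast_recv("foo", timedelta(milliseconds=1))
```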

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125978
Approved by: https://github.com/shuqiangzhang
2024-05-17 05:09:11 +00:00


# Owner(s): ["oncall: distributed"]
from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


class TestCollectives(TestCase):
    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
                self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        store = dist.HashStore()
        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()