# pytorch/test/distributed/test_control_collectives.py
# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase
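
# These tests cover the private control collectives API: the
# dist._ControlCollectives interface and its Store-backed implementation,
# dist._StoreCollectives, which coordinates ranks through a shared Store.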


# simple example of user code that takes the base class ControlCollectives
# and executes multiple different collectives
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
    timeout = timedelta(seconds=10)
    # first a barrier
    collectives.barrier("1", timeout, True)
    # then an all_sum
    out = collectives.all_sum("2", rank, timeout)
    return out
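
# Each collective call is identified by a key ("1" and "2" above); as
# test_unique below demonstrates, a key can only be used once per store.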


class TestCollectives(TestCase):
    def test_barrier(self) -> None:
        store = dist.HashStore()
        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))
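
    # Each test drives world_size "ranks" with an in-process ThreadPool rather
    # than spawning processes; _StoreCollectives only needs the shared
    # HashStore to coordinate, so threads are sufficient.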

    def test_broadcast(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))
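
    # broadcast has a single sender per key; bytes payloads round-trip
    # unchanged, and (as in test_unique below) str payloads are also accepted.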

    def test_gather(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))
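
    # gather_recv collects one payload per rank, including the receiver's own
    # contribution, ordered by rank; str inputs come back as bytes.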

    def test_scatter(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)

            # both the sender and the receivers get the slot matching their rank
            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))
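
    # The timeout tests below construct only rank 1 against a fresh store, so
    # no peer ever arrives and each collective must fail once the 1 ms timeout
    # elapses.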

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)
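
    # Note: all_sum surfaces a "barrier failed" message here, which suggests
    # the store-backed all_sum synchronizes ranks via an internal barrier.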

    def test_unique(self) -> None:
        store = dist.HashStore()
        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)
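
    # A key is consumed by its first use and is rejected for any subsequent
    # collective of any type on the same store.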

    def test_simple_user_func(self) -> None:
        store = dist.HashStore()
        world_size = 4

        def f(rank: int) -> None:
            # the user needs to create a concrete ControlCollectives
            # implementation, but simple_user_func itself does not need to
            # change across implementations
            store_collectives = dist._StoreCollectives(store, rank, world_size)
            out = simple_user_func(store_collectives, rank)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))


if __name__ == "__main__":
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()