mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555 Approved by: https://github.com/ezyang ghstack dependencies: #144551, #144554
215 lines
7.2 KiB
Python
215 lines
7.2 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
from datetime import timedelta
|
|
from multiprocessing.pool import ThreadPool
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
# simple example of user code that takes the base class ControlCollectives
|
|
# and executes multiple different collectives
|
|
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
|
|
timeout = timedelta(seconds=10)
|
|
# first a barrier
|
|
collectives.barrier("1", timeout, True)
|
|
# then an all_sum
|
|
out = collectives.all_sum("2", rank, timeout)
|
|
return out
|
|
|
|
|
|
class TestCollectives(TestCase):
    """Tests for ``dist._StoreCollectives`` on top of a ``dist.HashStore``.

    Multi-rank tests create one store, then run one thread per rank; every
    thread builds its own ``_StoreCollectives`` over that shared store.
    Timeout tests run a single rank with a 1 ms timeout and check the error
    message that is raised when the other ranks never show up.
    """

    def _run_on_all_ranks(self, fn, world_size: int) -> None:
        """Call ``fn(rank)`` for each rank in ``range(world_size)``, one thread per rank.

        Args:
            fn: Callable taking the rank as its only argument.
            world_size: Number of ranks, and also the number of worker threads.
        """
        # The pool is sized to world_size so that all ranks can be inside a
        # collective at the same time.
        with ThreadPool(world_size) as pool:
            # map() re-raises an exception raised in a worker, so a failed
            # assertion inside ``fn`` fails the calling test.
            pool.map(fn, range(world_size))

    def test_barrier(self) -> None:
        """Two ranks enter the same barrier and both return."""
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        self._run_on_all_ranks(f, world_size)

    def test_broadcast(self) -> None:
        """Rank 2 broadcasts ``b"data"``; every other rank receives it."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        self._run_on_all_ranks(f, world_size)

    def test_gather(self) -> None:
        """Rank 2 gathers ``str(rank)`` from all ranks, ordered by rank."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                # The receiver also passes its own contribution.
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        self._run_on_all_ranks(f, world_size)

    def test_scatter(self) -> None:
        """Rank 2 scatters one string per rank; each rank gets its own entry."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
            # Checked on the sender as well as on the receivers.
            self.assertEqual(out, str(rank).encode())

        self._run_on_all_ranks(f, world_size)

    def test_all_sum(self) -> None:
        """Every rank contributes its rank and gets back ``sum(range(world_size))``."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        self._run_on_all_ranks(f, world_size)

    def test_broadcast_timeout(self) -> None:
        """``broadcast_recv`` with no sender raises "Wait timeout"."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        """``gather_recv`` run by rank 1 alone reports ranks 0, 2, 3 as missing."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        """``scatter_recv`` with no sender raises "Wait timeout"."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        """``all_gather`` run by rank 1 alone reports ranks 0, 2, 3 as missing."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        """``barrier`` run by rank 1 alone reports ranks 0, 2, 3 as missing."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        """``all_sum`` run by rank 1 alone fails with a missing-ranks error."""
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        # NOTE(review): the expected message names "barrier", not "all_sum";
        # presumably all_sum synchronizes through a barrier internally --
        # confirm against the implementation before changing this string.
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        """Once key "foo" has been used, every collective rejects that key."""
        store = dist.HashStore()

        # NOTE(review): rank 1 with world_size 1 lies outside range(world_size);
        # presumably irrelevant to the key-reuse check exercised here -- confirm.
        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)

    def test_simple_user_func(self) -> None:
        """``simple_user_func`` returns the rank sum when given store collectives."""
        store = dist.HashStore()
        world_size = 4

        def f(rank: int) -> None:
            # Users need to create the concrete child collectives themselves,
            # but simple_user_func does not need to change for different
            # child collectives.
            store_collectives = dist._StoreCollectives(store, rank, world_size)
            out = simple_user_func(store_collectives, rank)
            self.assertEqual(out, sum(range(world_size)))

        self._run_on_all_ranks(f, world_size)
|
|
|
|
|
|
if __name__ == "__main__":
    # Refuse to start if something has already initialized CUDA in this
    # (main) process by the time the module is run as a script.
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
|