mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is the first PR of a series in an attempt to re-submit #134592 as smaller PRs. In distributed tests: - Ensure all files which should call run_tests do call run_tests. - Raise a RuntimeError on tests which have been disabled (not run). - Remove any remaining uses of "unittest.main()". Cc @wconstab @clee2000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154628 Approved by: https://github.com/Skylion007
122 lines
3.7 KiB
Python
122 lines
3.7 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
from unittest import mock
|
|
|
|
import torch.distributed as c10d
|
|
from torch.distributed.collective_utils import all_gather, broadcast
|
|
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
|
|
class TestCollectiveUtils(MultiProcessTestCase):
    """Tests for ``broadcast`` and ``all_gather`` from
    ``torch.distributed.collective_utils``.

    Each helper is exercised in three situations: over a Gloo process group,
    with no process group at all, and with a callable that raises.
    """

    def setUp(self):
        """Run the base-class setup, then call ``_spawn_processes()``."""
        super().setUp()
        self._spawn_processes()

    def tearDown(self) -> None:
        """Delegate teardown to the base class."""
        super().tearDown()

    def opts(self, threads=2):
        """Build the Gloo process-group options used by these tests.

        Args:
            threads: Value stored in the options' ``_threads`` field.

        Returns:
            A ``ProcessGroupGloo._Options`` instance with ``_timeout`` set to
            50.0 and ``_threads`` set to ``threads``.
        """
        options = c10d.ProcessGroupGloo._Options()
        options._timeout = 50.0
        options._threads = threads
        return options

    def _new_gloo_group(self):
        """Initialise the default Gloo group and return a new group.

        Rendezvous goes through a ``FileStore`` at ``self.file_name``; the
        returned group is created with the options from ``opts()``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.new_group(pg_options=self.opts())

    def test_broadcast_result(self) -> None:
        """
        Basic unit test for broadcast using a process group of default world size.
        """
        pg = self._new_gloo_group()

        # Broadcast from two different source ranks: every rank must receive
        # the source's value, and only the source may evaluate the callable.
        for src_rank in (0, 1):
            func = mock.MagicMock(return_value=pg.rank())

            res = broadcast(data_or_fn=func, rank=src_rank, pg=pg)
            # unittest assertion rather than a bare ``assert`` so the check
            # still runs when Python is started with -O.
            self.assertEqual(
                res, src_rank, f"Expect res to be {src_rank} (got {res})"
            )

            if pg.rank() == src_rank:
                func.assert_called_once()
            else:
                func.assert_not_called()

    def test_broadcast_result_no_pg(self) -> None:
        """
        Ensure broadcast has no dependency on torch.distributed when run in single process.
        """
        func = mock.MagicMock()
        broadcast(data_or_fn=func, rank=0)
        func.assert_called_once()

    def test_broadcast_result_raises_exceptions_from_func(
        self,
    ) -> None:
        """
        Ensure broadcast exception is propagated properly.
        """
        # no process group
        func = mock.MagicMock(side_effect=Exception("test exception"))
        with self.assertRaisesRegex(Exception, "test exception"):
            broadcast(data_or_fn=func, rank=0)

    def test_all_gather_result(self) -> None:
        """
        Basic unit test for all_gather using a process group of default world size.
        """
        pg = self._new_gloo_group()

        func = mock.MagicMock(return_value=pg.rank())

        res = all_gather(data_or_fn=func, pg=pg)
        func.assert_called_once()
        # Each rank contributes its own rank, so the gathered list is
        # 0 .. world_size - 1.
        self.assertEqual(
            res,
            list(range(self.world_size)),
            f"Expect res to be list of 0 through {self.world_size - 1} (got {res})",
        )

    def test_all_gather_result_no_pg(self) -> None:
        """
        Ensure all_gather has no dependency on torch.distributed when run in single process.
        """
        func = mock.MagicMock()
        all_gather(data_or_fn=func)
        func.assert_called_once()

    def test_all_gather_result_raises_exceptions_from_func(
        self,
    ) -> None:
        """
        Ensure all_gather exception is propagated properly.
        """
        # no process group
        func = mock.MagicMock(side_effect=Exception("test exception"))
        with self.assertRaisesRegex(Exception, "test exception"):
            all_gather(data_or_fn=func)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only start the test run when this file is executed as a script.
    run_tests()
|