mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add __main__ guards to distributed tests (#154628)
This is the first PR of a series in an attempt to re-submit #134592 as smaller PRs. In distributed tests: - Ensure all files which should call run_tests do call run_tests. - Raise a RuntimeError on tests which have been disabled (not run) - Remove any remaining uses of "unittest.main()"" Cc @wconstab @clee2000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154628 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
c8d44a2296
commit
3f34d26040
@ -34,7 +34,6 @@ from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
|
||||
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
|
||||
from torch.distributed.elastic.rendezvous.api import RendezvousGracefulExitError
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
def do_nothing():
|
||||
@ -650,4 +649,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -1466,3 +1466,10 @@ class LocalElasticAgentTest(unittest.TestCase):
|
||||
)
|
||||
def test_rank_restart_after_failure(self):
|
||||
self.run_test_with_backend(backend="c10d", test_to_run=self.fail_rank_one_once)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -141,4 +141,7 @@ class RedirectsTest(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -411,3 +411,10 @@ class ElasticLaunchTest(unittest.TestCase):
|
||||
launch_agent(config, simple_rank_scale, [])
|
||||
rdzv_handler_mock.shutdown.assert_called_once()
|
||||
record_event_mock.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -84,3 +84,10 @@ class LaunchTest(unittest.TestCase):
|
||||
self.assertSetEqual(
|
||||
{str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -8,7 +8,7 @@ import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.multiprocessing as mp
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
||||
from torch.testing._internal.common_utils import load_tests
|
||||
from torch.testing._internal.common_utils import load_tests, run_tests
|
||||
|
||||
|
||||
# Torch distributed.nn is not available in windows
|
||||
@ -246,3 +246,7 @@ class TestDistributedNNFunctions(MultiProcessTestCase):
|
||||
z.backward()
|
||||
x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
|
||||
self.assertEqual(x.grad, x_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -5,6 +5,7 @@ from unittest import mock
|
||||
import torch.distributed as c10d
|
||||
from torch.distributed.collective_utils import all_gather, broadcast
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
class TestCollectiveUtils(MultiProcessTestCase):
|
||||
@ -114,3 +115,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
|
||||
expected_exception = "test exception"
|
||||
with self.assertRaisesRegex(Exception, expected_exception):
|
||||
all_gather(data_or_fn=func)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user