Add __main__ guards to distributed tests (#154628)

This is the first PR in a series re-submitting #134592 as smaller PRs.

In distributed tests:

- Ensure all files which should call run_tests actually call it (both guard patterns are sketched after this list).
- Raise a RuntimeError in tests which have been disabled (not run).
- Remove any remaining uses of "unittest.main()".
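
For reference, here is a minimal sketch of the two `__main__` guard patterns applied in the diff below (illustrative only; the import path is the one used in the changed files):

```python
from torch.testing._internal.common_utils import run_tests

# Pattern 1: test files that are part of CI discovery invoke run_tests().
if __name__ == "__main__":
    run_tests()
```

```python
# Pattern 2: test files that are deliberately not discovered fail loudly
# instead of silently running nothing.
if __name__ == "__main__":
    raise RuntimeError(
        "This test is not currently used and should be "
        "enabled in discover_tests.py if required."
    )
```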

Cc @wconstab @clee2000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154628
Approved by: https://github.com/Skylion007
Author: Anthony Barbier
Date: 2025-06-04 14:39:54 +00:00
Committed by: PyTorch MergeBot
Parent: c8d44a2296
Commit: 3f34d26040
7 changed files with 39 additions and 4 deletions

File 1 of 7

@@ -34,7 +34,6 @@ from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
 from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
 from torch.distributed.elastic.rendezvous.api import RendezvousGracefulExitError
 from torch.distributed.elastic.utils.distributed import get_free_port
-from torch.testing._internal.common_utils import run_tests
 def do_nothing():
@@ -650,4 +649,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
 if __name__ == "__main__":
-    run_tests()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

File 2 of 7

@@ -1466,3 +1466,10 @@ class LocalElasticAgentTest(unittest.TestCase):
     )
     def test_rank_restart_after_failure(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.fail_rank_one_once)
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

File 3 of 7

@@ -141,4 +141,7 @@ class RedirectsTest(unittest.TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

File 4 of 7

@@ -411,3 +411,10 @@ class ElasticLaunchTest(unittest.TestCase):
         launch_agent(config, simple_rank_scale, [])
         rdzv_handler_mock.shutdown.assert_called_once()
         record_event_mock.assert_called_once()
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

File 5 of 7

@@ -84,3 +84,10 @@ class LaunchTest(unittest.TestCase):
         self.assertSetEqual(
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )
+
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test is not currently used and should be "
+        "enabled in discover_tests.py if required."
+    )

File 6 of 7

@@ -8,7 +8,7 @@ import torch
 import torch.distributed as c10d
 import torch.multiprocessing as mp
 from torch.testing._internal.common_distributed import MultiProcessTestCase
-from torch.testing._internal.common_utils import load_tests
+from torch.testing._internal.common_utils import load_tests, run_tests
 # Torch distributed.nn is not available in windows
@@ -246,3 +246,7 @@ class TestDistributedNNFunctions(MultiProcessTestCase):
         z.backward()
         x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
         self.assertEqual(x.grad, x_s)
+
+
+if __name__ == "__main__":
+    run_tests()

File 7 of 7

@@ -5,6 +5,7 @@ from unittest import mock
 import torch.distributed as c10d
 from torch.distributed.collective_utils import all_gather, broadcast
 from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_utils import run_tests
 class TestCollectiveUtils(MultiProcessTestCase):
@@ -114,3 +115,7 @@ class TestCollectiveUtils(MultiProcessTestCase):
         expected_exception = "test exception"
         with self.assertRaisesRegex(Exception, expected_exception):
             all_gather(data_or_fn=func)
+
+
+if __name__ == "__main__":
+    run_tests()