mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-18 09:34:57 +08:00
In a trunk failure today, we saw the same test running on both the trunk and slow shards. The reason is that this test didn't invoke `super().setUp()`, so test features such as slow-test and disabled-test handling didn't apply to it. I used Claude to find all test classes with a `setUp()` method that didn't call `super().setUp()` and patched all of them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167163 Approved by: https://github.com/malfet
102 lines
3.6 KiB
Python
102 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
|
# Owner(s): ["oncall: r2p"]
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from torch.distributed.launcher.api import launch_agent, LaunchConfig
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class LauncherApiTest(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
# Save original environment variable if it exists
|
|
self.original_signals_env = os.environ.get(
|
|
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
|
|
)
|
|
|
|
def tearDown(self):
|
|
# Restore original environment variable
|
|
if self.original_signals_env is not None:
|
|
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
|
|
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
|
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
|
|
|
@patch("torch.distributed.launcher.api.LocalElasticAgent")
|
|
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
|
|
def test_launch_agent_sets_signals_env_var(self, mock_get_handler, mock_agent):
|
|
"""Test that launch_agent sets the TORCHELASTIC_SIGNALS_TO_HANDLE environment variable."""
|
|
# Setup
|
|
config = LaunchConfig(
|
|
min_nodes=1,
|
|
max_nodes=1,
|
|
nproc_per_node=1,
|
|
signals_to_handle="SIGTERM,SIGUSR1,SIGUSR2",
|
|
)
|
|
entrypoint = "dummy_script.py"
|
|
args = []
|
|
|
|
# Make sure the environment variable doesn't exist before the test
|
|
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
|
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
|
|
|
# Mock agent.run() to return a MagicMock
|
|
mock_agent_instance = MagicMock()
|
|
mock_agent_instance.run.return_value = MagicMock(
|
|
is_failed=lambda: False, return_values={}
|
|
)
|
|
mock_agent.return_value = mock_agent_instance
|
|
|
|
# Call launch_agent
|
|
launch_agent(config, entrypoint, args)
|
|
|
|
# Verify that the environment variable was set correctly
|
|
self.assertEqual(
|
|
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"], "SIGTERM,SIGUSR1,SIGUSR2"
|
|
)
|
|
|
|
@patch("torch.distributed.launcher.api.LocalElasticAgent")
|
|
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
|
|
def test_launch_agent_default_signals(self, mock_get_handler, mock_agent):
|
|
"""Test that launch_agent uses the default signals if not specified."""
|
|
# Setup
|
|
config = LaunchConfig(
|
|
min_nodes=1,
|
|
max_nodes=1,
|
|
nproc_per_node=1,
|
|
# Not specifying signals_to_handle, should use default
|
|
)
|
|
entrypoint = "dummy_script.py"
|
|
args = []
|
|
|
|
# Make sure the environment variable doesn't exist before the test
|
|
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
|
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
|
|
|
# Mock agent.run() to return a MagicMock
|
|
mock_agent_instance = MagicMock()
|
|
mock_agent_instance.run.return_value = MagicMock(
|
|
is_failed=lambda: False, return_values={}
|
|
)
|
|
mock_agent.return_value = mock_agent_instance
|
|
|
|
# Call launch_agent
|
|
launch_agent(config, entrypoint, args)
|
|
|
|
# Verify that the environment variable was set to the default value
|
|
self.assertEqual(
|
|
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"],
|
|
"SIGTERM,SIGINT,SIGHUP,SIGQUIT",
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the imported test runner.
    run_tests()
|