mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
154849 Add support to handle IGUSR1 and SIGUSR2 in multiprocessing (#160690)
Fixes #154849 This change addresses the request to add support for SIGUSR1 and SIGUSR2 signals in torchrun for SLURM environments. Changes supports these signals through the configurable `TORCHELASTIC_SIGNALS_TO_HANDLE` environment variable and signals_to_handle parameter from laucher api Tests: For validations purpose: test_signal_handling.py, simple_test_api_signal_handling.py, Unit Tests: for launcher changes:launcher/test_api.py for api changes: multiprocessing/test_api.py E2E: test_run.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/160690 Approved by: https://github.com/fduwjj
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d66a3b894
commit
b498299953
331
test/distributed/elastic/multiprocessing/test_api.py
Normal file
331
test/distributed/elastic/multiprocessing/test_api.py
Normal file
@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import signal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.api import (
|
||||
_terminate_process_handler,
|
||||
PContext,
|
||||
SignalException,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class SignalHandlingTest(TestCase):
|
||||
def setUp(self):
|
||||
# Save original environment variable if it exists
|
||||
self.original_signals_env = os.environ.get(
|
||||
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original environment variable
|
||||
if self.original_signals_env is not None:
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
|
||||
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
def test_terminate_process_handler(self):
|
||||
"""Test that the terminate process handler raises SignalException with the correct signal."""
|
||||
signum = signal.SIGTERM
|
||||
with self.assertRaises(SignalException) as cm:
|
||||
_terminate_process_handler(signum, None)
|
||||
|
||||
self.assertEqual(cm.exception.sigval, signal.SIGTERM)
|
||||
# The signal is represented as a number in the string representation
|
||||
self.assertIn(f"Process {os.getpid()} got signal: {signum}", str(cm.exception))
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_registers_default_signals(
|
||||
self, mock_logger, mock_signal, mock_threading
|
||||
):
|
||||
"""Test that the start method registers the default signals."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = (
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Remove environment variable if it exists to test default behavior
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that the signal handler was registered for the default signals
|
||||
expected_signals = ["SIGTERM", "SIGINT", "SIGHUP", "SIGQUIT"]
|
||||
|
||||
# Count the number of calls to signal.signal
|
||||
signal_calls = 0
|
||||
for call in mock_signal.signal.call_args_list:
|
||||
args, _ = call
|
||||
sig, handler = args
|
||||
signal_calls += 1
|
||||
# Verify the handler is our _terminate_process_handler
|
||||
self.assertEqual(handler, _terminate_process_handler)
|
||||
|
||||
# Verify we registered the expected number of signals
|
||||
self.assertEqual(signal_calls, len(expected_signals))
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_registers_custom_signals(
|
||||
self, mock_logger, mock_signal, mock_threading
|
||||
):
|
||||
"""Test that the start method registers custom signals from the environment variable."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = (
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Set custom signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that the signal handler was registered for the custom signals
|
||||
expected_signals = ["SIGTERM", "SIGUSR1", "SIGUSR2"]
|
||||
|
||||
# Count the number of calls to signal.signal
|
||||
signal_calls = 0
|
||||
for call in mock_signal.signal.call_args_list:
|
||||
args, _ = call
|
||||
sig, handler = args
|
||||
signal_calls += 1
|
||||
# Verify the handler is our _terminate_process_handler
|
||||
self.assertEqual(handler, _terminate_process_handler)
|
||||
|
||||
# Verify we registered the expected number of signals
|
||||
self.assertEqual(signal_calls, len(expected_signals))
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_handles_invalid_signals(
|
||||
self, mock_logger, mock_signal, mock_threading
|
||||
):
|
||||
"""Test that the start method handles invalid signals gracefully."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = (
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Set invalid signals in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
|
||||
|
||||
# Mock the signal module to not have the INVALID_SIGNAL attribute
|
||||
# but have SIGTERM
|
||||
mock_signal.SIGTERM = signal.SIGTERM
|
||||
# Remove INVALID_SIGNAL attribute if it exists
|
||||
if hasattr(mock_signal, "INVALID_SIGNAL"):
|
||||
delattr(mock_signal, "INVALID_SIGNAL")
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that the warning was logged for the invalid signal
|
||||
# The exact message may vary, so let's check if warning was called with INVALID_SIGNAL
|
||||
warning_calls = [
|
||||
call
|
||||
for call in mock_logger.warning.call_args_list
|
||||
if "INVALID_SIGNAL" in str(call)
|
||||
]
|
||||
self.assertTrue(len(warning_calls) > 0, "Expected warning about INVALID_SIGNAL")
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_handles_windows_signals(
|
||||
self, mock_logger, mock_signal, mock_threading
|
||||
):
|
||||
"""Test that the start method handles Windows-specific signal behavior."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = (
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Set signals including ones not supported on Windows
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
|
||||
|
||||
# Mock signal attributes
|
||||
mock_signal.SIGTERM = signal.SIGTERM
|
||||
mock_signal.SIGHUP = signal.SIGHUP
|
||||
mock_signal.SIGUSR1 = signal.SIGUSR1
|
||||
|
||||
# Mock IS_WINDOWS to be True
|
||||
with patch("torch.distributed.elastic.multiprocessing.api.IS_WINDOWS", True):
|
||||
# Mock signal.signal to raise RuntimeError for Windows-unsupported signals
|
||||
def signal_side_effect(sig, handler):
|
||||
if sig in [signal.SIGHUP, signal.SIGUSR1]:
|
||||
raise RuntimeError("Signal not supported on Windows")
|
||||
|
||||
mock_signal.signal.side_effect = signal_side_effect
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that the info was logged for the unsupported signals
|
||||
# Check if any info calls contain the expected messages
|
||||
info_calls = [str(call) for call in mock_logger.info.call_args_list]
|
||||
sighup_logged = any(
|
||||
"SIGHUP" in call and "Windows" in call for call in info_calls
|
||||
)
|
||||
sigusr1_logged = any(
|
||||
"SIGUSR1" in call and "Windows" in call for call in info_calls
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
sighup_logged,
|
||||
f"Expected SIGHUP Windows message in info calls: {info_calls}",
|
||||
)
|
||||
self.assertTrue(
|
||||
sigusr1_logged,
|
||||
f"Expected SIGUSR1 Windows message in info calls: {info_calls}",
|
||||
)
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_not_main_thread(self, mock_logger, mock_threading):
|
||||
"""Test that the start method warns when not called from the main thread."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
|
||||
mock_threading.main_thread.return_value = MagicMock()
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that the warning was logged
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Failed to register signal handlers since torchelastic is running on a child thread. "
|
||||
"This could lead to orphaned worker processes if the torchrun is terminated."
|
||||
)
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.threading")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.signal")
|
||||
@patch("torch.distributed.elastic.multiprocessing.api.logger")
|
||||
def test_start_supports_sigusr1_and_sigusr2(
|
||||
self, mock_logger, mock_signal, mock_threading
|
||||
):
|
||||
"""Test that the start method properly supports SIGUSR1 and SIGUSR2 signals."""
|
||||
# Setup
|
||||
mock_threading.current_thread.return_value = (
|
||||
mock_threading.main_thread.return_value
|
||||
)
|
||||
mock_pcontext = MagicMock(spec=PContext)
|
||||
# Mock the _stdout_tail and _stderr_tail attributes
|
||||
mock_pcontext._stdout_tail = MagicMock()
|
||||
mock_pcontext._stderr_tail = MagicMock()
|
||||
|
||||
# Set environment variable to include SIGUSR1 and SIGUSR2
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
|
||||
|
||||
# Mock signal attributes to have SIGUSR1 and SIGUSR2
|
||||
mock_signal.SIGUSR1 = signal.SIGUSR1
|
||||
mock_signal.SIGUSR2 = signal.SIGUSR2
|
||||
|
||||
# Call the start method
|
||||
PContext.start(mock_pcontext)
|
||||
|
||||
# Verify that signal.signal was called for both SIGUSR1 and SIGUSR2
|
||||
signal_calls = mock_signal.signal.call_args_list
|
||||
registered_signals = [
|
||||
call[0][0] for call in signal_calls
|
||||
] # Extract the signal from each call
|
||||
|
||||
# Verify both SIGUSR1 and SIGUSR2 were registered
|
||||
self.assertIn(
|
||||
signal.SIGUSR1, registered_signals, "SIGUSR1 should be registered"
|
||||
)
|
||||
self.assertIn(
|
||||
signal.SIGUSR2, registered_signals, "SIGUSR2 should be registered"
|
||||
)
|
||||
|
||||
# Verify the correct handler was registered for both signals
|
||||
for call in signal_calls:
|
||||
sig, handler = call[0]
|
||||
if sig in [signal.SIGUSR1, signal.SIGUSR2]:
|
||||
self.assertEqual(
|
||||
handler,
|
||||
_terminate_process_handler,
|
||||
f"Signal {sig} should use _terminate_process_handler",
|
||||
)
|
||||
|
||||
# Verify that info messages were logged for successful registration
|
||||
info_calls = [str(call) for call in mock_logger.info.call_args_list]
|
||||
sigusr1_logged = any(
|
||||
"SIGUSR1" in call and "Registered signal handler" in call
|
||||
for call in info_calls
|
||||
)
|
||||
sigusr2_logged = any(
|
||||
"SIGUSR2" in call and "Registered signal handler" in call
|
||||
for call in info_calls
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
sigusr1_logged,
|
||||
f"Expected SIGUSR1 registration message in info calls: {info_calls}",
|
||||
)
|
||||
self.assertTrue(
|
||||
sigusr2_logged,
|
||||
f"Expected SIGUSR2 registration message in info calls: {info_calls}",
|
||||
)
|
||||
|
||||
# Verify _start was called
|
||||
mock_pcontext._start.assert_called_once()
|
||||
# Verify _stdout_tail.start() and _stderr_tail.start() were called
|
||||
mock_pcontext._stdout_tail.start.assert_called_once()
|
||||
mock_pcontext._stderr_tail.start.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
100
test/distributed/launcher/test_api.py
Normal file
100
test/distributed/launcher/test_api.py
Normal file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from torch.distributed.launcher.api import launch_agent, LaunchConfig
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class LauncherApiTest(TestCase):
|
||||
def setUp(self):
|
||||
# Save original environment variable if it exists
|
||||
self.original_signals_env = os.environ.get(
|
||||
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original environment variable
|
||||
if self.original_signals_env is not None:
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
|
||||
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
@patch("torch.distributed.launcher.api.LocalElasticAgent")
|
||||
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
|
||||
def test_launch_agent_sets_signals_env_var(self, mock_get_handler, mock_agent):
|
||||
"""Test that launch_agent sets the TORCHELASTIC_SIGNALS_TO_HANDLE environment variable."""
|
||||
# Setup
|
||||
config = LaunchConfig(
|
||||
min_nodes=1,
|
||||
max_nodes=1,
|
||||
nproc_per_node=1,
|
||||
signals_to_handle="SIGTERM,SIGUSR1,SIGUSR2",
|
||||
)
|
||||
entrypoint = "dummy_script.py"
|
||||
args = []
|
||||
|
||||
# Make sure the environment variable doesn't exist before the test
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
# Mock agent.run() to return a MagicMock
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_instance.run.return_value = MagicMock(
|
||||
is_failed=lambda: False, return_values={}
|
||||
)
|
||||
mock_agent.return_value = mock_agent_instance
|
||||
|
||||
# Call launch_agent
|
||||
launch_agent(config, entrypoint, args)
|
||||
|
||||
# Verify that the environment variable was set correctly
|
||||
self.assertEqual(
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"], "SIGTERM,SIGUSR1,SIGUSR2"
|
||||
)
|
||||
|
||||
@patch("torch.distributed.launcher.api.LocalElasticAgent")
|
||||
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
|
||||
def test_launch_agent_default_signals(self, mock_get_handler, mock_agent):
|
||||
"""Test that launch_agent uses the default signals if not specified."""
|
||||
# Setup
|
||||
config = LaunchConfig(
|
||||
min_nodes=1,
|
||||
max_nodes=1,
|
||||
nproc_per_node=1,
|
||||
# Not specifying signals_to_handle, should use default
|
||||
)
|
||||
entrypoint = "dummy_script.py"
|
||||
args = []
|
||||
|
||||
# Make sure the environment variable doesn't exist before the test
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
# Mock agent.run() to return a MagicMock
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_instance.run.return_value = MagicMock(
|
||||
is_failed=lambda: False, return_values={}
|
||||
)
|
||||
mock_agent.return_value = mock_agent_instance
|
||||
|
||||
# Call launch_agent
|
||||
launch_agent(config, entrypoint, args)
|
||||
|
||||
# Verify that the environment variable was set to the default value
|
||||
self.assertEqual(
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"],
|
||||
"SIGTERM,SIGINT,SIGHUP,SIGQUIT",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
90
test/distributed/test_run.py
Normal file
90
test/distributed/test_run.py
Normal file
@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch.distributed.run as run
|
||||
from torch.distributed.launcher.api import launch_agent, LaunchConfig
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class RunTest(TestCase):
|
||||
def setUp(self):
|
||||
# Save original environment variable if it exists
|
||||
self.original_signals_env = os.environ.get(
|
||||
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original environment variable
|
||||
if self.original_signals_env is not None:
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
|
||||
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
def test_signals_to_handle_default(self):
|
||||
"""Test that the default value for signals_to_handle is correctly set."""
|
||||
parser = run.get_args_parser()
|
||||
args = parser.parse_args(["dummy_script.py"])
|
||||
self.assertEqual(args.signals_to_handle, "SIGTERM,SIGINT,SIGHUP,SIGQUIT")
|
||||
|
||||
def test_signals_to_handle_custom(self):
|
||||
"""Test that a custom value for signals_to_handle is correctly parsed."""
|
||||
parser = run.get_args_parser()
|
||||
args = parser.parse_args(
|
||||
["--signals-to-handle=SIGTERM,SIGUSR1,SIGUSR2", "dummy_script.py"]
|
||||
)
|
||||
self.assertEqual(args.signals_to_handle, "SIGTERM,SIGUSR1,SIGUSR2")
|
||||
|
||||
def test_config_from_args_signals_to_handle(self):
|
||||
"""Test that the signals_to_handle argument is correctly passed to LaunchConfig."""
|
||||
parser = run.get_args_parser()
|
||||
args = parser.parse_args(
|
||||
["--signals-to-handle=SIGTERM,SIGUSR1,SIGUSR2", "dummy_script.py"]
|
||||
)
|
||||
config, _, _ = run.config_from_args(args)
|
||||
self.assertEqual(config.signals_to_handle, "SIGTERM,SIGUSR1,SIGUSR2")
|
||||
|
||||
@patch("torch.distributed.launcher.api.LocalElasticAgent")
|
||||
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
|
||||
def test_launch_agent_sets_environment_variable(self, mock_get_handler, mock_agent):
|
||||
"""Test that launch_agent sets the TORCHELASTIC_SIGNALS_TO_HANDLE environment variable."""
|
||||
# Setup
|
||||
config = LaunchConfig(
|
||||
min_nodes=1,
|
||||
max_nodes=1,
|
||||
nproc_per_node=1,
|
||||
signals_to_handle="SIGTERM,SIGUSR1,SIGUSR2",
|
||||
)
|
||||
entrypoint = "dummy_script.py"
|
||||
args = []
|
||||
|
||||
# Make sure the environment variable doesn't exist before the test
|
||||
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
|
||||
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
|
||||
|
||||
# Mock agent.run() to return a MagicMock
|
||||
mock_agent_instance = MagicMock()
|
||||
mock_agent_instance.run.return_value = MagicMock(
|
||||
is_failed=lambda: False, return_values={}
|
||||
)
|
||||
mock_agent.return_value = mock_agent_instance
|
||||
|
||||
# Call launch_agent
|
||||
launch_agent(config, entrypoint, args)
|
||||
|
||||
# Verify that the environment variable was set correctly
|
||||
self.assertEqual(
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"], "SIGTERM,SIGUSR1,SIGUSR2"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -477,11 +477,35 @@ class PContext(abc.ABC):
|
||||
def start(self) -> None:
|
||||
"""Start processes using parameters defined in the constructor."""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGTERM, _terminate_process_handler)
|
||||
signal.signal(signal.SIGINT, _terminate_process_handler)
|
||||
if not IS_WINDOWS:
|
||||
signal.signal(signal.SIGHUP, _terminate_process_handler)
|
||||
signal.signal(signal.SIGQUIT, _terminate_process_handler)
|
||||
# Register signal handlers for the signals specified in the environment variable
|
||||
signals_to_handle = os.environ.get(
|
||||
"TORCHELASTIC_SIGNALS_TO_HANDLE", "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
|
||||
)
|
||||
signal_list = signals_to_handle.split(",")
|
||||
|
||||
for sig_name in signal_list:
|
||||
try:
|
||||
sig = getattr(signal, sig_name.strip())
|
||||
signal.signal(sig, _terminate_process_handler)
|
||||
logger.info("Registered signal handler for %s", sig_name)
|
||||
except (AttributeError, ValueError) as e:
|
||||
logger.warning(
|
||||
"Failed to register signal handler for %s: %s", sig_name, e
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if IS_WINDOWS and sig_name.strip() in [
|
||||
"SIGHUP",
|
||||
"SIGQUIT",
|
||||
"SIGUSR1",
|
||||
"SIGUSR2",
|
||||
]:
|
||||
logger.info(
|
||||
"Signal %s is not supported on Windows, skipping", sig_name
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to register signal handler for %s: %s", sig_name, e
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to register signal handlers since torchelastic is running on a child thread. "
|
||||
|
@ -6,6 +6,7 @@
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@ -95,6 +96,7 @@ class LaunchConfig:
|
||||
local_addr: Optional[str] = None
|
||||
event_log_handler: str = "null"
|
||||
numa_options: Optional[NumaOptions] = None
|
||||
signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
|
||||
|
||||
def __post_init__(self):
|
||||
default_timeout = 900
|
||||
@ -240,6 +242,7 @@ def launch_agent(
|
||||
"metrics_cfg": config.metrics_cfg,
|
||||
"event_log_handler": config.event_log_handler,
|
||||
"numa_options": config.numa_options,
|
||||
"signals_to_handle": config.signals_to_handle,
|
||||
},
|
||||
)
|
||||
|
||||
@ -255,6 +258,9 @@ def launch_agent(
|
||||
|
||||
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
|
||||
|
||||
# Set the signals to handle in the environment variable
|
||||
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle
|
||||
|
||||
spec = WorkerSpec(
|
||||
role=config.role,
|
||||
local_world_size=config.nproc_per_node,
|
||||
|
@ -645,6 +645,17 @@ def get_args_parser() -> ArgumentParser:
|
||||
featuring a single L3 cache per socket.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--signals-to-handle",
|
||||
"--signals_to_handle",
|
||||
action=env,
|
||||
type=str,
|
||||
default="SIGTERM,SIGINT,SIGHUP,SIGQUIT",
|
||||
help="Comma-separated list of signals to handle and forward to subprocesses. "
|
||||
"Default: SIGTERM,SIGINT,SIGHUP,SIGQUIT. "
|
||||
"Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).",
|
||||
)
|
||||
|
||||
#
|
||||
# Positional arguments.
|
||||
#
|
||||
@ -861,6 +872,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
|
||||
logs_specs=logs_specs,
|
||||
event_log_handler=args.event_log_handler,
|
||||
numa_options=numa_options,
|
||||
signals_to_handle=args.signals_to_handle,
|
||||
)
|
||||
|
||||
with_python = not args.no_python
|
||||
|
Reference in New Issue
Block a user