154849 Add support to handle IGUSR1 and SIGUSR2 in multiprocessing (#160690)

Fixes #154849

This change addresses the request to add support for SIGUSR1 and SIGUSR2 signals in torchrun for SLURM environments.  Changes supports these signals through the configurable `TORCHELASTIC_SIGNALS_TO_HANDLE` environment variable and signals_to_handle parameter from laucher api

Tests:
For validations purpose:
test_signal_handling.py,
simple_test_api_signal_handling.py,

Unit Tests:
for launcher changes:launcher/test_api.py
for api changes:  multiprocessing/test_api.py
E2E: test_run.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160690
Approved by: https://github.com/fduwjj
This commit is contained in:
SandishKumarHN
2025-09-09 22:23:02 +00:00
committed by PyTorch MergeBot
parent 4d66a3b894
commit b498299953
6 changed files with 568 additions and 5 deletions

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
from unittest.mock import MagicMock, patch
from torch.distributed.elastic.multiprocessing.api import (
_terminate_process_handler,
PContext,
SignalException,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class SignalHandlingTest(TestCase):
def setUp(self):
# Save original environment variable if it exists
self.original_signals_env = os.environ.get(
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
)
def tearDown(self):
# Restore original environment variable
if self.original_signals_env is not None:
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
def test_terminate_process_handler(self):
"""Test that the terminate process handler raises SignalException with the correct signal."""
signum = signal.SIGTERM
with self.assertRaises(SignalException) as cm:
_terminate_process_handler(signum, None)
self.assertEqual(cm.exception.sigval, signal.SIGTERM)
# The signal is represented as a number in the string representation
self.assertIn(f"Process {os.getpid()} got signal: {signum}", str(cm.exception))
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_registers_default_signals(
self, mock_logger, mock_signal, mock_threading
):
"""Test that the start method registers the default signals."""
# Setup
mock_threading.current_thread.return_value = (
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Remove environment variable if it exists to test default behavior
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
# Call the start method
PContext.start(mock_pcontext)
# Verify that the signal handler was registered for the default signals
expected_signals = ["SIGTERM", "SIGINT", "SIGHUP", "SIGQUIT"]
# Count the number of calls to signal.signal
signal_calls = 0
for call in mock_signal.signal.call_args_list:
args, _ = call
sig, handler = args
signal_calls += 1
# Verify the handler is our _terminate_process_handler
self.assertEqual(handler, _terminate_process_handler)
# Verify we registered the expected number of signals
self.assertEqual(signal_calls, len(expected_signals))
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
mock_pcontext._stdout_tail.start.assert_called_once()
mock_pcontext._stderr_tail.start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_registers_custom_signals(
self, mock_logger, mock_signal, mock_threading
):
"""Test that the start method registers custom signals from the environment variable."""
# Setup
mock_threading.current_thread.return_value = (
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Set custom signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
# Call the start method
PContext.start(mock_pcontext)
# Verify that the signal handler was registered for the custom signals
expected_signals = ["SIGTERM", "SIGUSR1", "SIGUSR2"]
# Count the number of calls to signal.signal
signal_calls = 0
for call in mock_signal.signal.call_args_list:
args, _ = call
sig, handler = args
signal_calls += 1
# Verify the handler is our _terminate_process_handler
self.assertEqual(handler, _terminate_process_handler)
# Verify we registered the expected number of signals
self.assertEqual(signal_calls, len(expected_signals))
# Verify _start was called
mock_pcontext._start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_handles_invalid_signals(
self, mock_logger, mock_signal, mock_threading
):
"""Test that the start method handles invalid signals gracefully."""
# Setup
mock_threading.current_thread.return_value = (
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Set invalid signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
# Mock the signal module to not have the INVALID_SIGNAL attribute
# but have SIGTERM
mock_signal.SIGTERM = signal.SIGTERM
# Remove INVALID_SIGNAL attribute if it exists
if hasattr(mock_signal, "INVALID_SIGNAL"):
delattr(mock_signal, "INVALID_SIGNAL")
# Call the start method
PContext.start(mock_pcontext)
# Verify that the warning was logged for the invalid signal
# The exact message may vary, so let's check if warning was called with INVALID_SIGNAL
warning_calls = [
call
for call in mock_logger.warning.call_args_list
if "INVALID_SIGNAL" in str(call)
]
self.assertTrue(len(warning_calls) > 0, "Expected warning about INVALID_SIGNAL")
# Verify _start was called
mock_pcontext._start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_handles_windows_signals(
self, mock_logger, mock_signal, mock_threading
):
"""Test that the start method handles Windows-specific signal behavior."""
# Setup
mock_threading.current_thread.return_value = (
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Set signals including ones not supported on Windows
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
# Mock signal attributes
mock_signal.SIGTERM = signal.SIGTERM
mock_signal.SIGHUP = signal.SIGHUP
mock_signal.SIGUSR1 = signal.SIGUSR1
# Mock IS_WINDOWS to be True
with patch("torch.distributed.elastic.multiprocessing.api.IS_WINDOWS", True):
# Mock signal.signal to raise RuntimeError for Windows-unsupported signals
def signal_side_effect(sig, handler):
if sig in [signal.SIGHUP, signal.SIGUSR1]:
raise RuntimeError("Signal not supported on Windows")
mock_signal.signal.side_effect = signal_side_effect
# Call the start method
PContext.start(mock_pcontext)
# Verify that the info was logged for the unsupported signals
# Check if any info calls contain the expected messages
info_calls = [str(call) for call in mock_logger.info.call_args_list]
sighup_logged = any(
"SIGHUP" in call and "Windows" in call for call in info_calls
)
sigusr1_logged = any(
"SIGUSR1" in call and "Windows" in call for call in info_calls
)
self.assertTrue(
sighup_logged,
f"Expected SIGHUP Windows message in info calls: {info_calls}",
)
self.assertTrue(
sigusr1_logged,
f"Expected SIGUSR1 Windows message in info calls: {info_calls}",
)
# Verify _start was called
mock_pcontext._start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_not_main_thread(self, mock_logger, mock_threading):
"""Test that the start method warns when not called from the main thread."""
# Setup
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
mock_threading.main_thread.return_value = MagicMock()
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Call the start method
PContext.start(mock_pcontext)
# Verify that the warning was logged
mock_logger.warning.assert_called_with(
"Failed to register signal handlers since torchelastic is running on a child thread. "
"This could lead to orphaned worker processes if the torchrun is terminated."
)
# Verify _start was called
mock_pcontext._start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@patch("torch.distributed.elastic.multiprocessing.api.logger")
def test_start_supports_sigusr1_and_sigusr2(
self, mock_logger, mock_signal, mock_threading
):
"""Test that the start method properly supports SIGUSR1 and SIGUSR2 signals."""
# Setup
mock_threading.current_thread.return_value = (
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
# Mock the _stdout_tail and _stderr_tail attributes
mock_pcontext._stdout_tail = MagicMock()
mock_pcontext._stderr_tail = MagicMock()
# Set environment variable to include SIGUSR1 and SIGUSR2
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
# Mock signal attributes to have SIGUSR1 and SIGUSR2
mock_signal.SIGUSR1 = signal.SIGUSR1
mock_signal.SIGUSR2 = signal.SIGUSR2
# Call the start method
PContext.start(mock_pcontext)
# Verify that signal.signal was called for both SIGUSR1 and SIGUSR2
signal_calls = mock_signal.signal.call_args_list
registered_signals = [
call[0][0] for call in signal_calls
] # Extract the signal from each call
# Verify both SIGUSR1 and SIGUSR2 were registered
self.assertIn(
signal.SIGUSR1, registered_signals, "SIGUSR1 should be registered"
)
self.assertIn(
signal.SIGUSR2, registered_signals, "SIGUSR2 should be registered"
)
# Verify the correct handler was registered for both signals
for call in signal_calls:
sig, handler = call[0]
if sig in [signal.SIGUSR1, signal.SIGUSR2]:
self.assertEqual(
handler,
_terminate_process_handler,
f"Signal {sig} should use _terminate_process_handler",
)
# Verify that info messages were logged for successful registration
info_calls = [str(call) for call in mock_logger.info.call_args_list]
sigusr1_logged = any(
"SIGUSR1" in call and "Registered signal handler" in call
for call in info_calls
)
sigusr2_logged = any(
"SIGUSR2" in call and "Registered signal handler" in call
for call in info_calls
)
self.assertTrue(
sigusr1_logged,
f"Expected SIGUSR1 registration message in info calls: {info_calls}",
)
self.assertTrue(
sigusr2_logged,
f"Expected SIGUSR2 registration message in info calls: {info_calls}",
)
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
mock_pcontext._stdout_tail.start.assert_called_once()
mock_pcontext._stderr_tail.start.assert_called_once()
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from unittest.mock import MagicMock, patch
from torch.distributed.launcher.api import launch_agent, LaunchConfig
from torch.testing._internal.common_utils import run_tests, TestCase
class LauncherApiTest(TestCase):
def setUp(self):
# Save original environment variable if it exists
self.original_signals_env = os.environ.get(
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
)
def tearDown(self):
# Restore original environment variable
if self.original_signals_env is not None:
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
@patch("torch.distributed.launcher.api.LocalElasticAgent")
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
def test_launch_agent_sets_signals_env_var(self, mock_get_handler, mock_agent):
"""Test that launch_agent sets the TORCHELASTIC_SIGNALS_TO_HANDLE environment variable."""
# Setup
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=1,
signals_to_handle="SIGTERM,SIGUSR1,SIGUSR2",
)
entrypoint = "dummy_script.py"
args = []
# Make sure the environment variable doesn't exist before the test
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
# Mock agent.run() to return a MagicMock
mock_agent_instance = MagicMock()
mock_agent_instance.run.return_value = MagicMock(
is_failed=lambda: False, return_values={}
)
mock_agent.return_value = mock_agent_instance
# Call launch_agent
launch_agent(config, entrypoint, args)
# Verify that the environment variable was set correctly
self.assertEqual(
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"], "SIGTERM,SIGUSR1,SIGUSR2"
)
@patch("torch.distributed.launcher.api.LocalElasticAgent")
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
def test_launch_agent_default_signals(self, mock_get_handler, mock_agent):
"""Test that launch_agent uses the default signals if not specified."""
# Setup
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=1,
# Not specifying signals_to_handle, should use default
)
entrypoint = "dummy_script.py"
args = []
# Make sure the environment variable doesn't exist before the test
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
# Mock agent.run() to return a MagicMock
mock_agent_instance = MagicMock()
mock_agent_instance.run.return_value = MagicMock(
is_failed=lambda: False, return_values={}
)
mock_agent.return_value = mock_agent_instance
# Call launch_agent
launch_agent(config, entrypoint, args)
# Verify that the environment variable was set to the default value
self.assertEqual(
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"],
"SIGTERM,SIGINT,SIGHUP,SIGQUIT",
)
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,90 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from unittest.mock import MagicMock, patch
import torch.distributed.run as run
from torch.distributed.launcher.api import launch_agent, LaunchConfig
from torch.testing._internal.common_utils import run_tests, TestCase
class RunTest(TestCase):
def setUp(self):
# Save original environment variable if it exists
self.original_signals_env = os.environ.get(
"TORCHELASTIC_SIGNALS_TO_HANDLE", None
)
def tearDown(self):
# Restore original environment variable
if self.original_signals_env is not None:
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = self.original_signals_env
elif "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
def test_signals_to_handle_default(self):
"""Test that the default value for signals_to_handle is correctly set."""
parser = run.get_args_parser()
args = parser.parse_args(["dummy_script.py"])
self.assertEqual(args.signals_to_handle, "SIGTERM,SIGINT,SIGHUP,SIGQUIT")
def test_signals_to_handle_custom(self):
"""Test that a custom value for signals_to_handle is correctly parsed."""
parser = run.get_args_parser()
args = parser.parse_args(
["--signals-to-handle=SIGTERM,SIGUSR1,SIGUSR2", "dummy_script.py"]
)
self.assertEqual(args.signals_to_handle, "SIGTERM,SIGUSR1,SIGUSR2")
def test_config_from_args_signals_to_handle(self):
"""Test that the signals_to_handle argument is correctly passed to LaunchConfig."""
parser = run.get_args_parser()
args = parser.parse_args(
["--signals-to-handle=SIGTERM,SIGUSR1,SIGUSR2", "dummy_script.py"]
)
config, _, _ = run.config_from_args(args)
self.assertEqual(config.signals_to_handle, "SIGTERM,SIGUSR1,SIGUSR2")
@patch("torch.distributed.launcher.api.LocalElasticAgent")
@patch("torch.distributed.launcher.api.rdzv_registry.get_rendezvous_handler")
def test_launch_agent_sets_environment_variable(self, mock_get_handler, mock_agent):
"""Test that launch_agent sets the TORCHELASTIC_SIGNALS_TO_HANDLE environment variable."""
# Setup
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=1,
signals_to_handle="SIGTERM,SIGUSR1,SIGUSR2",
)
entrypoint = "dummy_script.py"
args = []
# Make sure the environment variable doesn't exist before the test
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
del os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]
# Mock agent.run() to return a MagicMock
mock_agent_instance = MagicMock()
mock_agent_instance.run.return_value = MagicMock(
is_failed=lambda: False, return_values={}
)
mock_agent.return_value = mock_agent_instance
# Call launch_agent
launch_agent(config, entrypoint, args)
# Verify that the environment variable was set correctly
self.assertEqual(
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"], "SIGTERM,SIGUSR1,SIGUSR2"
)
if __name__ == "__main__":
run_tests()

View File

@ -477,11 +477,35 @@ class PContext(abc.ABC):
def start(self) -> None:
"""Start processes using parameters defined in the constructor."""
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
# Register signal handlers for the signals specified in the environment variable
signals_to_handle = os.environ.get(
"TORCHELASTIC_SIGNALS_TO_HANDLE", "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
)
signal_list = signals_to_handle.split(",")
for sig_name in signal_list:
try:
sig = getattr(signal, sig_name.strip())
signal.signal(sig, _terminate_process_handler)
logger.info("Registered signal handler for %s", sig_name)
except (AttributeError, ValueError) as e:
logger.warning(
"Failed to register signal handler for %s: %s", sig_name, e
)
except RuntimeError as e:
if IS_WINDOWS and sig_name.strip() in [
"SIGHUP",
"SIGQUIT",
"SIGUSR1",
"SIGUSR2",
]:
logger.info(
"Signal %s is not supported on Windows, skipping", sig_name
)
else:
logger.warning(
"Failed to register signal handler for %s: %s", sig_name, e
)
else:
logger.warning(
"Failed to register signal handlers since torchelastic is running on a child thread. "

View File

@ -6,6 +6,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import uuid
from dataclasses import dataclass, field
@ -95,6 +96,7 @@ class LaunchConfig:
local_addr: Optional[str] = None
event_log_handler: str = "null"
numa_options: Optional[NumaOptions] = None
signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
def __post_init__(self):
default_timeout = 900
@ -240,6 +242,7 @@ def launch_agent(
"metrics_cfg": config.metrics_cfg,
"event_log_handler": config.event_log_handler,
"numa_options": config.numa_options,
"signals_to_handle": config.signals_to_handle,
},
)
@ -255,6 +258,9 @@ def launch_agent(
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
# Set the signals to handle in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle
spec = WorkerSpec(
role=config.role,
local_world_size=config.nproc_per_node,

View File

@ -645,6 +645,17 @@ def get_args_parser() -> ArgumentParser:
featuring a single L3 cache per socket.""",
)
parser.add_argument(
"--signals-to-handle",
"--signals_to_handle",
action=env,
type=str,
default="SIGTERM,SIGINT,SIGHUP,SIGQUIT",
help="Comma-separated list of signals to handle and forward to subprocesses. "
"Default: SIGTERM,SIGINT,SIGHUP,SIGQUIT. "
"Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).",
)
#
# Positional arguments.
#
@ -861,6 +872,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
logs_specs=logs_specs,
event_log_handler=args.event_log_handler,
numa_options=numa_options,
signals_to_handle=args.signals_to_handle,
)
with_python = not args.no_python