[torchelastic] Don't do signal handling when off the main thread (#135088)

Summary:
In multiprocessing, signal handling is not possible if the thread is not the main thread. This resulted in the following error:
> "ValueError('signal only works in main thread of the main interpreter')"

To address this issue, the diff checks whether the thread is the main thread and, if not, skips signal handling.

Test Plan:
Before this change, MAST job failed:
https://fburl.com/mlhub/iq2m10v8

With this change, MAST job succeeded:
https://fburl.com/mlhub/q6kb8343

Differential Revision: D62166943

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135088
Approved by: https://github.com/d4l3k
This commit is contained in:
Yiwen Shi
2024-09-06 14:47:03 +00:00
committed by PyTorch MergeBot
parent a086882d72
commit 3a9e33dca8
2 changed files with 16 additions and 5 deletions

View File

@ -16,6 +16,7 @@ import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
@ -470,11 +471,17 @@ class PContext(abc.ABC):
def start(self) -> None:
"""Start processes using parameters defined in the constructor."""
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
else:
logger.warning(
"Failed to register signal handlers since torchelastic is running on a child thread. "
"This could lead to orphaned worker processes if the torchrun is terminated."
)
self._start()
self._stdout_tail.start()
self._stderr_tail.start()