mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: If torch.multiprocessing.spawn is used to launch non-daemonic processes (the default since #14391), the spawned children won't be automatically terminated when the parent terminates. On Linux, we can address this by setting PR_SET_PDEATHSIG, which delivers a configurable signal to child processes when their parent terminates. Fixes #14394. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14491 Differential Revision: D13270374 Pulled By: pietern fbshipit-source-id: 092c9d3c3cea2622c3766b467957bc27a1bd500c
169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import multiprocessing
|
|
import multiprocessing.connection
|
|
import signal
|
|
import sys
|
|
|
|
from . import _prctl_pr_set_pdeathsig
|
|
|
|
|
|
def _wrap(fn, i, args, error_queue):
    """Entry point executed inside each spawned child process.

    Runs ``fn(i, *args)``; on an unexpected exception the formatted
    traceback is forwarded to the parent via ``error_queue`` and the
    child exits with status 1.
    """
    # Ask the kernel to deliver SIGINT to this child when the parent dies.
    # prctl(2) is a Linux-specific system call; on other platforms the
    # call below has no effect. It ensures non-daemonic children can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        # SIGINT means the parent killed us on purpose — exit quietly.
        pass
    except Exception:
        import traceback

        # Ship the original traceback text to the parent so it can
        # re-raise with full context.
        error_queue.put(traceback.format_exc())
        sys.exit(1)
|
|
|
|
|
|
def _python_version_check():
    """Raise ``RuntimeError`` when running on Python older than 3.4.

    The spawn helpers rely on Python 3 multiprocessing features
    (e.g. process sentinels usable with ``connection.wait``), so bail
    out early with an actionable message on unsupported interpreters.
    """
    if sys.version_info >= (3, 4):
        return
    raise RuntimeError("Requires python 3.4 or higher to use "
                       "torch.multiprocessing.spawn and "
                       "torch.multiprocessing.SpawnContext helper "
                       "to launch multiple processes. If you are using "
                       "this for distributed training and have a lower "
                       "version of python, please use "
                       "torch.distributed.launch instead.")
|
|
|
|
|
|
class SpawnContext:
    """Tracks the processes and per-process error queues created by ``spawn``.

    Returned by ``spawn(..., join=False)`` so the caller can later join
    the child processes and receive any error raised in them.
    """

    def __init__(self, processes, error_queues):
        _python_version_check()
        # One queue per process; a child pushes its formatted traceback
        # here (see _wrap) before exiting with a non-zero status.
        self.error_queues = error_queues
        self.processes = processes
        # Map each process sentinel (a handle that becomes ready when the
        # process exits, pollable via multiprocessing.connection.wait)
        # back to the process index, so join() can tell which process
        # finished. Entries are popped as processes are joined.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        # PIDs of all spawned processes, in spawn order.
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        # wait() returns the subset of sentinels that became ready within
        # the timeout (possibly empty).
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            # Pop so an already-joined process is never waited on again.
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                # Remember only the FIRST failing process; its cause is
                # what gets reported below.
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed
        # (e.g. was killed by a signal) rather than raising a Python
        # exception, so report the exit status instead.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # Negative exit code means termination by a signal;
                # translate it to the signal's name for readability.
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        # The child put its formatted traceback on the queue (see _wrap);
        # re-raise it here so the parent sees the original cause.
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
|
|
|
|
|
|
def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Arguments:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.

    Returns:
        None if ``join`` is ``True``,
        :class:`~SpawnContext` if ``join`` is ``False``

    """
    _python_version_check()
    # Always use the 'spawn' start method so children re-import the
    # target module instead of inheriting parent state via fork.
    ctx = multiprocessing.get_context('spawn')

    queues = []
    workers = []
    for rank in range(nprocs):
        queue = ctx.SimpleQueue()
        worker = ctx.Process(
            target=_wrap,
            args=(fn, rank, args, queue),
            daemon=daemon,
        )
        worker.start()
        queues.append(queue)
        workers.append(worker)

    context = SpawnContext(workers, queues)
    if not join:
        # Caller takes over joining via the returned context.
        return context

    # Block until every process has been joined, or join() raises with
    # the cause of the first failure.
    while not context.join():
        pass
|