Add torch.multiprocessing.spawn helper (#13518)

Summary:
This helper addresses a common pattern where one spawns N processes to
work on some common task (e.g. parallel preprocessing or multiple
training loops).

A straightforward approach is to use the multiprocessing API directly
and then consecutively call join on the resulting processes.

This pattern breaks down in the face of errors. If one of the
processes terminates with an exception or via some signal, and it is
not the first process that was launched, the join call on the first
process won't be affected. This helper seeks to solve this by waiting
on termination from any of the spawned processes. When any process
terminates with a non-zero exit status, it terminates the remaining
processes, and raises an exception in the parent process. If the
process terminated with an exception, it is propagated to the parent.
If the process terminated via a signal (e.g. SIGINT, SIGSEGV), this is
mentioned in the exception as well.

Requires Python >= 3.4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13518

Reviewed By: orionr

Differential Revision: D12929045

Pulled By: pietern

fbshipit-source-id: 00df19fa16a568d1e22f37a2ba65677ab0cce3fd
This commit is contained in:
Pieter Noordhuis
2018-11-06 14:06:33 -08:00
committed by Facebook Github Bot
parent 056f2cd238
commit be424de869
4 changed files with 268 additions and 0 deletions

View File

@ -0,0 +1,123 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import random
import signal
import sys
import time
import unittest
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp
def test_success_func(i):
pass
def test_success_single_arg_func(i, arg):
if arg:
arg.put(i)
def test_exception_single_func(i, arg):
if i == arg:
raise ValueError("legitimate exception from process %d" % i)
time.sleep(1.0)
def test_exception_all_func(i):
time.sleep(random.random() / 10)
raise ValueError("legitimate exception from process %d" % i)
def test_terminate_signal_func(i):
if i == 0:
os.kill(os.getpid(), signal.SIGABRT)
time.sleep(1.0)
def test_terminate_exit_func(i, arg):
if i == 0:
sys.exit(arg)
time.sleep(1.0)
def test_success_first_then_exception_func(i, arg):
if i == 0:
return
time.sleep(0.1)
raise ValueError("legitimate exception")
@unittest.skipIf(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase):
def test_success(self):
mp.spawn(test_success_func, nprocs=2)
def test_success_non_blocking(self):
spawn_context = mp.spawn(test_success_func, nprocs=2, join=False)
# After all processes (nproc=2) have joined it must return True
spawn_context.join(timeout=None)
spawn_context.join(timeout=None)
self.assertTrue(spawn_context.join(timeout=None))
def test_first_argument_index(self):
context = mp.get_context("spawn")
queue = context.SimpleQueue()
mp.spawn(test_success_single_arg_func, args=(queue,), nprocs=2)
self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))
def test_exception_single(self):
nprocs = 2
for i in range(nprocs):
with self.assertRaisesRegex(
Exception,
"\nValueError: legitimate exception from process %d$" % i,
):
mp.spawn(test_exception_single_func, args=(i,), nprocs=nprocs)
def test_exception_all(self):
with self.assertRaisesRegex(
Exception,
"\nValueError: legitimate exception from process (0|1)$",
):
mp.spawn(test_exception_all_func, nprocs=2)
def test_terminate_signal(self):
# SIGABRT is aliased with SIGIOT
message = "process 0 terminated with signal (SIGABRT|SIGIOT)"
# Termination through with signal is expressed as a negative exit code
# in multiprocessing, so we know it was a signal that caused the exit.
# This doesn't appear to exist on Windows, where the exit code is always
# positive, and therefore results in a different exception message.
# Exit code 22 means "ERROR_BAD_COMMAND".
if IS_WINDOWS:
message = "process 0 terminated with exit code 22"
with self.assertRaisesRegex(Exception, message):
mp.spawn(test_terminate_signal_func, nprocs=2)
def test_terminate_exit(self):
exitcode = 123
with self.assertRaisesRegex(
Exception,
"process 0 terminated with exit code %d" % exitcode,
):
mp.spawn(test_terminate_exit_func, args=(exitcode,), nprocs=2)
def test_success_first_then_exception(self):
exitcode = 123
with self.assertRaisesRegex(
Exception,
"ValueError: legitimate exception",
):
mp.spawn(test_success_first_then_exception_func, args=(exitcode,), nprocs=2)
if __name__ == '__main__':
run_tests()