Create processes in parallel in mp.start_processes for forkserver (#134629)

Summary:
This fixes https://github.com/pytorch/pytorch/issues/133010.
One way to fix this problem is to start processes in parallel in mp.start_processes when the forkserver start method is used; a usage sketch follows below.
Other changes in this diff:
Refactored the api_test test cases, which were repeating many tests due to inheritance.
Added a unit test for forkserver with parallel start enabled.
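A minimal usage sketch (the TORCH_MP_PARALLEL_START opt-in and the forkserver-only behavior come from this diff; the worker function and process count are illustrative, not part of the change):

    import os
    import torch.multiprocessing as mp

    def worker(rank):
        # start_processes calls the function with the process index as its first argument.
        print(f"hello from rank {rank}")

    if __name__ == "__main__":
        # Opt in to parallel startup; it only takes effect with start_method="forkserver".
        os.environ["TORCH_MP_PARALLEL_START"] = "1"
        mp.start_processes(worker, nprocs=4, start_method="forkserver")

Without the env var set to "1" (or with any other start method), start_processes behaves as before and starts children sequentially.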

Test Plan: Added unit tests

Differential Revision: D61878552

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134629
Approved by: https://github.com/d4l3k
Jia Li
2024-08-28 21:34:32 +00:00
committed by PyTorch MergeBot
parent f685018ea9
commit 20b62fed21
4 changed files with 281 additions and 140 deletions


@@ -226,36 +226,65 @@ def start_processes_zombie_test(
pc.close(e.sigval)
class _StartProcessesTest(TestCase):
def setUp(self):
super().setUp()
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
self._start_methods = ["spawn"]
def tearDown(self):
super().tearDown()
shutil.rmtree(self.test_dir)
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
):
os.kill(pid, 0)
def _test_zombie_workflow(
self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
) -> None:
mp_queue = mp.get_context("spawn").Queue()
child_nproc = 2
ctx = mp.spawn(
start_processes_zombie_test,
nprocs=1,
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
join=False,
)
total_processes = child_nproc + 1
pids = []
for _ in range(total_processes):
pids.append(mp_queue.get(timeout=120))
parent_pid = pids[0]
child_pids = pids[1:]
os.kill(parent_pid, signal.SIGTERM)
# Wait to give time for signal handlers to finish work
time.sleep(5)
for child_pid in child_pids:
# Killing parent should kill all children, we expect that each call to
# os.kill would raise OSError
with self.assertRaises(OSError):
os.kill(child_pid, 0)
# tests incompatible with tsan or asan
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
class StartProcessesTest(TestCase):
def setUp(self):
super().setUp()
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
self._start_methods = ["spawn"]
def tearDown(self):
super().tearDown()
shutil.rmtree(self.test_dir)
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
):
os.kill(pid, 0)
class StartProcessesAsFuncTest(_StartProcessesTest):
def test_to_map(self):
local_world_size = 2
self.assertEqual(
@@ -349,26 +378,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
def test_subprocess_context_close(self):
pc = start_processes(
name="sleep",
entrypoint=bin("zombie_test.py"),
args={0: (1,)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
)
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
def test_function_with_tensor(self):
for start_method in self._start_methods:
pc = start_processes(
name="dummy_compute",
entrypoint=dummy_compute,
args={},
envs={},
args={0: ()},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
start_method=start_method,
)
@@ -489,10 +505,55 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
mpc._poll()
self.assertEqual(4, mock_join.call_count)
@skip_but_pass_in_sandcastle_if(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that \
don't support multiprocessing with spawn start method",
)
def test_multiprocessing_context_poll_raises_exception(self):
mp_context = MultiprocessContext(
name="test_mp",
entrypoint=echo0,
args={0: (0, 1)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
),
start_method="spawn",
)
mp_context._pc = mock.Mock()
# Using mock since we cannot just set exitcode on process
mock_process = mock.Mock()
mock_process.exitcode = -1
mp_context._pc.processes = [mock_process]
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
mp_context._pc.join.side_effect = e
with mock.patch.object(mp_context, "close"):
run_result = mp_context._poll()
self.assertEqual(1, len(run_result.failures))
failure = run_result.failures[0]
self.assertEqual(
"Signal 1 (SIGHUP) received by PID 123", failure.message
)
class StartProcessesAsBinaryTest(_StartProcessesTest):
########################################
# start_processes as binary tests
########################################
def test_subprocess_context_close(self):
pc = start_processes(
name="sleep",
entrypoint=bin("zombie_test.py"),
args={0: (1,)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
)
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
def test_binary_exit(self):
FAIL = 138
pc = start_processes(
@@ -556,45 +617,11 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
with self.assertRaises(RuntimeError):
_validate_full_rank({}, 10, "")
@skip_but_pass_in_sandcastle_if(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that \
don't support multiprocessing with spawn start method",
)
def test_multiprocessing_context_poll_raises_exception(self):
mp_context = MultiprocessContext(
name="test_mp",
entrypoint=echo0,
args={0: (0, 1)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
),
start_method="spawn",
)
mp_context._pc = mock.Mock()
# Using mock since we cannot just set exitcode on process
mock_process = mock.Mock()
mock_process.exitcode = -1
mp_context._pc.processes = [mock_process]
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
mp_context._pc.join.side_effect = e
with mock.patch.object(mp_context, "close"):
run_result = mp_context._poll()
self.assertEqual(1, len(run_result.failures))
failure = run_result.failures[0]
self.assertEqual(
"Signal 1 (SIGHUP) received by PID 123", failure.message
)
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
class StartProcessesListTest(StartProcessesTest):
########################################
# start_processes as binary tests
########################################
class StartProcessesListAsFuncTest(_StartProcessesTest):
def test_function(self):
for start_method, redirs in product(
self._start_methods, redirects_oss_test()
@@ -634,6 +661,10 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
[f"hello stderr from {i}"], results.stderrs[i]
)
class StartProcessesListAsBinaryTest(_StartProcessesTest):
########################################
# start_processes as binary tests
########################################
def test_binary(self):
for redirs in redirects_oss_test():
with self.subTest(redirs=redirs):
@@ -700,7 +731,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
class StartProcessesNotCITest(StartProcessesTest):
class StartProcessesNotCIAsFuncTest(_StartProcessesTest):
@skip_if_pytest
def test_wrap_bad(self):
none = ""
@@ -733,32 +764,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assert_in_file(["hello stderr from 0"], stderr_log)
worker_finished_event_mock.wait.assert_called_once()
def test_binary_signal(self):
pc = start_processes(
name="echo",
entrypoint=bin("echo3.py"),
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
),
)
results = pc.wait(period=0.1)
self.assert_pids_noexist(pc.pids())
self.assertTrue(results.is_failed())
self.assertEqual(1, len(results.failures))
failure = results.failures[0]
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
if TEST_WITH_ASAN or TEST_WITH_TSAN:
# ASAN/TSAN exit code is 1.
self.assertEqual("<N/A>", failure.signal_name())
else:
self.assertEqual("SIGSEGV", failure.signal_name())
self.assertEqual("<NONE>", failure.error_file_data["message"])
def test_function_redirect_and_tee(self):
for start_method in self._start_methods:
with self.subTest(start_method=start_method):
@@ -871,42 +876,60 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
def test_no_zombie_process_binary(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
for s in signals:
self._test_zombie_workflow(bin("zombie_test.py"), s)
def test_no_zombie_process_function(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
for s in signals:
self._test_zombie_workflow(wait_fn, s)
def _test_zombie_workflow(
self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
) -> None:
mp_queue = mp.get_context("spawn").Queue()
child_nproc = 2
ctx = mp.spawn(
start_processes_zombie_test,
nprocs=1,
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
join=False,
class StartProcessesNotCIAsBinaryTest(_StartProcessesTest):
def test_binary_signal(self):
pc = start_processes(
name="echo",
entrypoint=bin("echo3.py"),
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
),
)
total_processes = child_nproc + 1
pids = []
for _ in range(total_processes):
pids.append(mp_queue.get(timeout=120))
parent_pid = pids[0]
child_pids = pids[1:]
os.kill(parent_pid, signal.SIGTERM)
# Wait to give time for signal handlers to finish work
time.sleep(5)
for child_pid in child_pids:
# Killing parent should kill all children, we expect that each call to
# os.kill would raise OSError
with self.assertRaises(OSError):
os.kill(child_pid, 0)
results = pc.wait(period=0.1)
self.assert_pids_noexist(pc.pids())
self.assertTrue(results.is_failed())
self.assertEqual(1, len(results.failures))
failure = results.failures[0]
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
if TEST_WITH_ASAN or TEST_WITH_TSAN:
# ASAN/TSAN exit code is 1.
self.assertEqual("<N/A>", failure.signal_name())
else:
self.assertEqual("SIGSEGV", failure.signal_name())
self.assertEqual("<NONE>", failure.error_file_data["message"])
def test_no_zombie_process_binary(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
for s in signals:
self._test_zombie_workflow(bin("zombie_test.py"), s)
class ForkServerTest(
StartProcessesAsFuncTest,
StartProcessesListAsFuncTest,
StartProcessesNotCIAsFuncTest,
):
def setUp(self):
super().setUp()
self._start_methods = ["forkserver"]
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
def tearDown(self):
super().tearDown()
if self.orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
if __name__ == "__main__":


@@ -8,9 +8,14 @@ import sys
import time
import unittest
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import (
IS_WINDOWS,
NO_MULTIPROCESSING_SPAWN,
run_tests,
TestCase,
)
def _test_success_func(i):
pass
@@ -220,6 +225,73 @@ class ForkTest(TestCase, _TestMultiProcessing):
start_method = 'fork'
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
orig_paralell_env_val = None
def setUp(self):
super().setUp()
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
def tearDown(self):
super().tearDown()
if self.orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
def test_forkserver_perf(self):
start_method = 'forkserver'
expensive = Expensive()
nprocs = 4
orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
# test the non parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the elapsed time should be at least {nprocs}x the sleep time
self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)
# test the parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the elapsed time should be less than {nprocs}x the sleep time
self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs)
if orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
class Expensive:
SLEEP_SECS = 5
# Simulate startup overhead such as large imports
time.sleep(SLEEP_SECS)
def __init__(self):
self.config: str = "*" * 1000000
def my_call(self, *args):
pass
class ErrorTest(TestCase):
def test_errors_pickleable(self):
for error in (


@@ -39,6 +39,7 @@ torch._C._multiprocessing_init()
"""Add helper function to spawn N processes and wait for completion of any of
them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import (
ENV_VAR_PARALLEL_START,
ProcessContext,
ProcessExitedException,
ProcessRaisedException,


@@ -9,13 +9,26 @@ import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
log = logging.getLogger(__name__)
__all__ = [
"ProcessContext",
"ProcessException",
"ProcessExitedException",
"ProcessRaisedException",
"spawn",
"SpawnContext",
"start_processes",
]
class ProcessException(Exception):
__slots__ = ["error_index", "error_pid"]
@@ -205,12 +218,32 @@ class SpawnContext(ProcessContext):
# Currently we only add this API first, we can consider adding it to documentation as
# needed in the future.
def start_processes(
fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
fn,
args=(),
nprocs=1,
join=True,
daemon=False,
start_method="spawn",
):
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this func will start processes in parallel if start_method is 'forkserver'.
# Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
# todo: investigate why spawn does not work with threadpool and raises SIGINT
if (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
):
log.info("Starting processes in parallel.")
start_parallel = True
else:
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
mp = multiprocessing.get_context(start_method)
error_files = []
processes = []
for i in range(nprocs):
error_files = [None] * nprocs
processes = [None] * nprocs
def start_process(i):
# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
@@ -228,9 +261,21 @@ def start_processes(
daemon=daemon,
)
process.start()
error_files.append(tf.name)
processes.append(process)
return i, process, tf.name
if not start_parallel:
for i in range(nprocs):
idx, process, tf_name = start_process(i)
error_files[idx] = tf_name
processes[idx] = process
else:
with ThreadPoolExecutor(max_workers=nprocs) as executor:
futures = [executor.submit(start_process, i) for i in range(nprocs)]
for fut in as_completed(futures):
idx, process, tf_name = fut.result()
# idx and process rank needs to be the same.
error_files[idx] = tf_name
processes[idx] = process
context = ProcessContext(processes, error_files)
if not join:
return context
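A note on the fan-out in the hunk above: as_completed yields futures in completion order, not submission order, which is why error_files and processes are pre-sized and filled by the returned index rather than appended to; that keeps each entry aligned with its process rank. A standalone illustration of the pattern (illustrative only, not the PyTorch code itself):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def make(i):
        # Stand-in for start_process(i): return the index along with the result
        # so the caller can slot it back in submission (rank) order.
        return i, f"resource-{i}"

    n = 4
    results = [None] * n
    with ThreadPoolExecutor(max_workers=n) as executor:
        futures = [executor.submit(make, i) for i in range(n)]
        for fut in as_completed(futures):
            idx, value = fut.result()
            results[idx] = value  # indexed assignment keeps rank ordering

    assert results == [f"resource-{i}" for i in range(n)]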