Revert "Allow mp.start_processes to create processes in parallel (#133707)"

This reverts commit 3546628a2a167ace6060737eeccf8ee8fd87ddc0.

Reverted https://github.com/pytorch/pytorch/pull/133707 on behalf of https://github.com/ZainRizvi due to sorry but trunk has been consistently broken since this PR was merged. See: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10529617600/job/29191757055) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/3546628a2a167ace6060737eeccf8ee8fd87ddc0) ([comment](https://github.com/pytorch/pytorch/pull/133707#issuecomment-2310709523))
This commit is contained in:
PyTorch MergeBot
2024-08-26 17:31:10 +00:00
parent d0ac5d55ba
commit adcce538b7
4 changed files with 144 additions and 311 deletions

View File

@ -226,65 +226,36 @@ def start_processes_zombie_test(
pc.close(e.sigval) pc.close(e.sigval)
class _StartProcessesTest(TestCase):
def setUp(self):
super().setUp()
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
self._start_methods = ["spawn"]
def tearDown(self):
super().tearDown()
shutil.rmtree(self.test_dir)
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
):
os.kill(pid, 0)
def _test_zombie_workflow(
self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
) -> None:
mp_queue = mp.get_context("spawn").Queue()
child_nproc = 2
ctx = mp.spawn(
start_processes_zombie_test,
nprocs=1,
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
join=False,
)
total_processes = child_nproc + 1
pids = []
for _ in range(total_processes):
pids.append(mp_queue.get(timeout=120))
parent_pid = pids[0]
child_pids = pids[1:]
os.kill(parent_pid, signal.SIGTERM)
# Wait to give time for signal handlers to finish work
time.sleep(5)
for child_pid in child_pids:
# Killing parent should kill all children, we expect that each call to
# os.kill would raise OSError
with self.assertRaises(OSError):
os.kill(child_pid, 0)
# tests incompatible with tsan or asan # tests incompatible with tsan or asan
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
class StartProcessesAsFuncTest(_StartProcessesTest): class StartProcessesTest(TestCase):
def setUp(self):
super().setUp()
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
self._start_methods = ["spawn"]
def tearDown(self):
super().tearDown()
shutil.rmtree(self.test_dir)
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
):
os.kill(pid, 0)
def test_to_map(self): def test_to_map(self):
local_world_size = 2 local_world_size = 2
self.assertEqual( self.assertEqual(
@ -378,13 +349,26 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped())
def test_subprocess_context_close(self):
pc = start_processes(
name="sleep",
entrypoint=bin("zombie_test.py"),
args={0: (1,)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
)
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
def test_function_with_tensor(self): def test_function_with_tensor(self):
for start_method in self._start_methods: for start_method in self._start_methods:
pc = start_processes( pc = start_processes(
name="dummy_compute", name="dummy_compute",
entrypoint=dummy_compute, entrypoint=dummy_compute,
args={0: ()}, args={},
envs={0: {}}, envs={},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
start_method=start_method, start_method=start_method,
) )
@ -505,55 +489,10 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
mpc._poll() mpc._poll()
self.assertEqual(4, mock_join.call_count) self.assertEqual(4, mock_join.call_count)
@skip_but_pass_in_sandcastle_if(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that \
don't support multiprocessing with spawn start method",
)
def test_multiprocessing_context_poll_raises_exception(self):
mp_context = MultiprocessContext(
name="test_mp",
entrypoint=echo0,
args={0: (0, 1)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
),
start_method="spawn",
)
mp_context._pc = mock.Mock()
# Using mock since we cannot just set exitcode on process
mock_process = mock.Mock()
mock_process.exitcode = -1
mp_context._pc.processes = [mock_process]
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
mp_context._pc.join.side_effect = e
with mock.patch.object(mp_context, "close"):
run_result = mp_context._poll()
self.assertEqual(1, len(run_result.failures))
failure = run_result.failures[0]
self.assertEqual(
"Signal 1 (SIGHUP) received by PID 123", failure.message
)
class StartProcessesAsBinaryTest(_StartProcessesTest):
######################################## ########################################
# start_processes as binary tests # start_processes as binary tests
######################################## ########################################
def test_subprocess_context_close(self):
pc = start_processes(
name="sleep",
entrypoint=bin("zombie_test.py"),
args={0: (1,)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
)
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
def test_binary_exit(self): def test_binary_exit(self):
FAIL = 138 FAIL = 138
pc = start_processes( pc = start_processes(
@ -617,11 +556,45 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_validate_full_rank({}, 10, "") _validate_full_rank({}, 10, "")
@skip_but_pass_in_sandcastle_if(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that \
don't support multiprocessing with spawn start method",
)
def test_multiprocessing_context_poll_raises_exception(self):
mp_context = MultiprocessContext(
name="test_mp",
entrypoint=echo0,
args={0: (0, 1)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
),
start_method="spawn",
)
mp_context._pc = mock.Mock()
# Using mock since we cannot just set exitcode on process
mock_process = mock.Mock()
mock_process.exitcode = -1
mp_context._pc.processes = [mock_process]
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
mp_context._pc.join.side_effect = e
with mock.patch.object(mp_context, "close"):
run_result = mp_context._poll()
self.assertEqual(1, len(run_result.failures))
failure = run_result.failures[0]
self.assertEqual(
"Signal 1 (SIGHUP) received by PID 123", failure.message
)
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
class StartProcessesListAsFuncTest(_StartProcessesTest): class StartProcessesListTest(StartProcessesTest):
########################################
# start_processes as binary tests
########################################
def test_function(self): def test_function(self):
for start_method, redirs in product( for start_method, redirs in product(
self._start_methods, redirects_oss_test() self._start_methods, redirects_oss_test()
@ -661,10 +634,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
[f"hello stderr from {i}"], results.stderrs[i] [f"hello stderr from {i}"], results.stderrs[i]
) )
class StartProcessesListAsBinaryTest(_StartProcessesTest):
########################################
# start_processes as binary tests
########################################
def test_binary(self): def test_binary(self):
for redirs in redirects_oss_test(): for redirs in redirects_oss_test():
with self.subTest(redirs=redirs): with self.subTest(redirs=redirs):
@ -731,7 +700,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI): if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
class StartProcessesNotCIAsFuncTest(_StartProcessesTest): class StartProcessesNotCITest(StartProcessesTest):
@skip_if_pytest @skip_if_pytest
def test_wrap_bad(self): def test_wrap_bad(self):
none = "" none = ""
@ -764,6 +733,32 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assert_in_file(["hello stderr from 0"], stderr_log) self.assert_in_file(["hello stderr from 0"], stderr_log)
worker_finished_event_mock.wait.assert_called_once() worker_finished_event_mock.wait.assert_called_once()
def test_binary_signal(self):
pc = start_processes(
name="echo",
entrypoint=bin("echo3.py"),
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
),
)
results = pc.wait(period=0.1)
self.assert_pids_noexist(pc.pids())
self.assertTrue(results.is_failed())
self.assertEqual(1, len(results.failures))
failure = results.failures[0]
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
if TEST_WITH_ASAN or TEST_WITH_TSAN:
# ASAN/TSAN exit code is 1.
self.assertEqual("<N/A>", failure.signal_name())
else:
self.assertEqual("SIGSEGV", failure.signal_name())
self.assertEqual("<NONE>", failure.error_file_data["message"])
def test_function_redirect_and_tee(self): def test_function_redirect_and_tee(self):
for start_method in self._start_methods: for start_method in self._start_methods:
with self.subTest(start_method=start_method): with self.subTest(start_method=start_method):
@ -876,60 +871,42 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped())
def test_no_zombie_process_function(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
for s in signals:
self._test_zombie_workflow(wait_fn, s)
class StartProcessesNotCIAsBinaryTest(_StartProcessesTest):
def test_binary_signal(self):
pc = start_processes(
name="echo",
entrypoint=bin("echo3.py"),
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
),
)
results = pc.wait(period=0.1)
self.assert_pids_noexist(pc.pids())
self.assertTrue(results.is_failed())
self.assertEqual(1, len(results.failures))
failure = results.failures[0]
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
if TEST_WITH_ASAN or TEST_WITH_TSAN:
# ASAN/TSAN exit code is 1.
self.assertEqual("<N/A>", failure.signal_name())
else:
self.assertEqual("SIGSEGV", failure.signal_name())
self.assertEqual("<NONE>", failure.error_file_data["message"])
def test_no_zombie_process_binary(self): def test_no_zombie_process_binary(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
for s in signals: for s in signals:
self._test_zombie_workflow(bin("zombie_test.py"), s) self._test_zombie_workflow(bin("zombie_test.py"), s)
class ForkServerTest( def test_no_zombie_process_function(self):
StartProcessesAsFuncTest, signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
StartProcessesListAsFuncTest, for s in signals:
StartProcessesNotCIAsFuncTest, self._test_zombie_workflow(wait_fn, s)
):
def setUp(self):
super().setUp()
self._start_methods = ["forkserver"]
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
def tearDown(self): def _test_zombie_workflow(
super().tearDown() self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
if self.orig_paralell_env_val is None: ) -> None:
del os.environ[mp.ENV_VAR_PARALLEL_START] mp_queue = mp.get_context("spawn").Queue()
else: child_nproc = 2
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val ctx = mp.spawn(
start_processes_zombie_test,
nprocs=1,
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
join=False,
)
total_processes = child_nproc + 1
pids = []
for _ in range(total_processes):
pids.append(mp_queue.get(timeout=120))
parent_pid = pids[0]
child_pids = pids[1:]
os.kill(parent_pid, signal.SIGTERM)
# Wait to give time for signal handlers to finish work
time.sleep(5)
for child_pid in child_pids:
# Killing parent should kill all children, we expect that each call to
# os.kill would raise OSError
with self.assertRaises(OSError):
os.kill(child_pid, 0)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -8,14 +8,9 @@ import sys
import time import time
import unittest import unittest
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.testing._internal.common_utils import (
IS_WINDOWS,
NO_MULTIPROCESSING_SPAWN,
run_tests,
TestCase,
)
def _test_success_func(i): def _test_success_func(i):
pass pass
@ -92,7 +87,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method):
# Kill self. This should take down the child processes as well. # Kill self. This should take down the child processes as well.
os.kill(os.getpid(), signal.SIGTERM) os.kill(os.getpid(), signal.SIGTERM)
class _TestMultiProcessing(TestCase): class _TestMultiProcessing:
start_method = None start_method = None
def test_success(self): def test_success(self):
@ -194,11 +189,10 @@ class _TestMultiProcessing(TestCase):
self.assertLess(time.time() - start, nested_child_sleep / 2) self.assertLess(time.time() - start, nested_child_sleep / 2)
time.sleep(0.1) time.sleep(0.1)
@unittest.skipIf( @unittest.skipIf(
NO_MULTIPROCESSING_SPAWN, NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that don't support the spawn start method") "Disabled for environments that don't support the spawn start method")
class SpawnTest(_TestMultiProcessing): class SpawnTest(TestCase, _TestMultiProcessing):
start_method = 'spawn' start_method = 'spawn'
def test_exception_raises(self): def test_exception_raises(self):
@ -222,103 +216,10 @@ class SpawnTest(_TestMultiProcessing):
IS_WINDOWS, IS_WINDOWS,
"Fork is only available on Unix", "Fork is only available on Unix",
) )
class ForkTest(_TestMultiProcessing): class ForkTest(TestCase, _TestMultiProcessing):
start_method = 'fork' start_method = 'fork'
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ForkServerTest(_TestMultiProcessing):
start_method = 'forkserver'
class _ParallelTest:
orig_paralell_env_val = None
def setUp(self):
super().setUp()
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
def tearDown(self):
super().tearDown()
if self.orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
@unittest.skipIf(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that don't support the spawn start method")
class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
def test_forkserver_perf(self):
start_method = 'forkserver'
expensive = Expensive()
nprocs = 6
orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
# test the non parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the time should be at least 6x the sleep time
self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)
# test the parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the time should be at most 1x the sleep time + small overhead
self.assertLess(elapsed, Expensive.SLEEP_SECS + 10)
if orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
class Expensive:
SLEEP_SECS = 10
# Simulate startup overhead such as large imports
time.sleep(SLEEP_SECS)
def __init__(self):
self.config: str = "*" * 1000000
def my_call(self, *args):
pass
class ErrorTest(TestCase): class ErrorTest(TestCase):
def test_errors_pickleable(self): def test_errors_pickleable(self):
for error in ( for error in (

View File

@ -39,7 +39,6 @@ torch._C._multiprocessing_init()
"""Add helper function to spawn N processes and wait for completion of any of """Add helper function to spawn N processes and wait for completion of any of
them. This depends `mp.get_context` which was added in Python 3.4.""" them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import ( from .spawn import (
ENV_VAR_PARALLEL_START,
ProcessContext, ProcessContext,
ProcessExitedException, ProcessExitedException,
ProcessRaisedException, ProcessRaisedException,

View File

@ -9,26 +9,13 @@ import sys
import tempfile import tempfile
import time import time
import warnings import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional from typing import Optional
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
__all__ = [
"ProcessContext",
"ProcessException",
"ProcessExitedException",
"ProcessRaisedException",
"spawn",
"SpawnContext",
"start_processes",
]
class ProcessException(Exception): class ProcessException(Exception):
__slots__ = ["error_index", "error_pid"] __slots__ = ["error_index", "error_pid"]
@ -218,31 +205,12 @@ class SpawnContext(ProcessContext):
# Currently we only add this API first, we can consider adding it to documentation as # Currently we only add this API first, we can consider adding it to documentation as
# needed in the future. # needed in the future.
def start_processes( def start_processes(
fn, fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
args=(),
nprocs=1,
join=True,
daemon=False,
start_method="spawn",
): ):
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this func will start processes in parallel if start_method is 'forkserver'.
# Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
# todo: investigate why spawn does not work with threadpool and raises SIGINT
if (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
):
start_parallel = True
else:
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
mp = multiprocessing.get_context(start_method) mp = multiprocessing.get_context(start_method)
error_files = [None] * nprocs error_files = []
processes = [None] * nprocs processes = []
for i in range(nprocs):
def start_process(i):
# Each process is assigned a file to write tracebacks to. We # Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception # use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously # occurred (vs an expected shutdown). Note: this previously
@ -260,21 +228,9 @@ def start_processes(
daemon=daemon, daemon=daemon,
) )
process.start() process.start()
return i, process, tf.name error_files.append(tf.name)
processes.append(process)
if not start_parallel:
for i in range(nprocs):
idx, process, tf_name = start_process(i)
error_files[idx] = tf_name
processes[idx] = process
else:
with ThreadPoolExecutor(max_workers=nprocs) as executor:
futures = [executor.submit(start_process, i) for i in range(nprocs)]
for fut in as_completed(futures):
idx, process, tf_name = fut.result()
# idx and process rank needs to be the same.
error_files[idx] = tf_name
processes[idx] = process
context = ProcessContext(processes, error_files) context = ProcessContext(processes, error_files)
if not join: if not join:
return context return context