Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Create processes in parallel in mp.start_processes for forkserver (#134629)
Summary: This fixes the issue reported in https://github.com/pytorch/pytorch/issues/133010. One way to fix the problem is to create the processes in parallel in mp.start_processes, which this change enables for the forkserver start method. Also in this diff: refactored the api_test test cases, which were repeating a lot of tests due to inheritance, and added a unit test for forkserver with parallel start enabled.

Test Plan: Added unit tests.

Differential Revision: D61878552

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134629
Approved by: https://github.com/d4l3k
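For reference, a minimal sketch of how the opt-in is exercised (the environment variable and the start_processes API are taken from this diff; the worker function and nprocs value are illustrative only, and forkserver requires a Unix host):

    import os
    import torch.multiprocessing as mp

    def worker(i):
        # i is the process index passed in by start_processes
        pass

    if __name__ == "__main__":
        # Opt in to parallel process creation; only honored for the forkserver start method.
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
        mp.start_processes(worker, nprocs=4, start_method="forkserver")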
@@ -226,36 +226,65 @@ def start_processes_zombie_test(
        pc.close(e.sigval)


class _StartProcessesTest(TestCase):
    def setUp(self):
        super().setUp()
        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
        self._start_methods = ["spawn"]

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.test_dir)

    def log_dir(self):
        return tempfile.mkdtemp(dir=self.test_dir)

    def assert_in_file(self, expected: List[str], filename: str) -> None:
        expected = [f"{line.rstrip()}\n" for line in expected]
        with open(filename) as fp:
            actual = fp.readlines()
            for line in expected:
                self.assertIn(line, actual)

    def assert_pids_noexist(self, pids: Dict[int, int]):
        for local_rank, pid in pids.items():
            with self.assertRaises(
                OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
            ):
                os.kill(pid, 0)

    def _test_zombie_workflow(
        self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
    ) -> None:
        mp_queue = mp.get_context("spawn").Queue()
        child_nproc = 2
        ctx = mp.spawn(
            start_processes_zombie_test,
            nprocs=1,
            args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
            join=False,
        )
        total_processes = child_nproc + 1
        pids = []
        for _ in range(total_processes):
            pids.append(mp_queue.get(timeout=120))
        parent_pid = pids[0]
        child_pids = pids[1:]

        os.kill(parent_pid, signal.SIGTERM)
        # Wait to give time for signal handlers to finish work
        time.sleep(5)
        for child_pid in child_pids:
            # Killing parent should kill all children, we expect that each call to
            # os.kill would raise OSError
            with self.assertRaises(OSError):
                os.kill(child_pid, 0)


# tests incompatible with tsan or asan
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

    class StartProcessesTest(TestCase):
        def setUp(self):
            super().setUp()
            self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
            self._start_methods = ["spawn"]

        def tearDown(self):
            super().tearDown()
            shutil.rmtree(self.test_dir)

        def log_dir(self):
            return tempfile.mkdtemp(dir=self.test_dir)

        def assert_in_file(self, expected: List[str], filename: str) -> None:
            expected = [f"{line.rstrip()}\n" for line in expected]
            with open(filename) as fp:
                actual = fp.readlines()
                for line in expected:
                    self.assertIn(line, actual)

        def assert_pids_noexist(self, pids: Dict[int, int]):
            for local_rank, pid in pids.items():
                with self.assertRaises(
                    OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
                ):
                    os.kill(pid, 0)

    class StartProcessesAsFuncTest(_StartProcessesTest):
        def test_to_map(self):
            local_world_size = 2
            self.assertEqual(
@@ -349,26 +378,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())

        def test_subprocess_context_close(self):
            pc = start_processes(
                name="sleep",
                entrypoint=bin("zombie_test.py"),
                args={0: (1,)},
                envs={0: {}},
                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
            )

            pids = pc.pids()
            pc.close()
            self.assert_pids_noexist(pids)

        def test_function_with_tensor(self):
            for start_method in self._start_methods:
                pc = start_processes(
                    name="dummy_compute",
                    entrypoint=dummy_compute,
                    args={},
                    envs={},
                    args={0: ()},
                    envs={0: {}},
                    logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
                    start_method=start_method,
                )
@@ -489,10 +505,55 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
                mpc._poll()
            self.assertEqual(4, mock_join.call_count)

        @skip_but_pass_in_sandcastle_if(
            NO_MULTIPROCESSING_SPAWN,
            "Disabled for environments that \
                don't support multiprocessing with spawn start method",
        )
        def test_multiprocessing_context_poll_raises_exception(self):
            mp_context = MultiprocessContext(
                name="test_mp",
                entrypoint=echo0,
                args={0: (0, 1)},
                envs={0: {}},
                logs_specs=DefaultLogsSpecs(
                    log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
                ),
                start_method="spawn",
            )
            mp_context._pc = mock.Mock()
            # Using mock since we cannot just set exitcode on process
            mock_process = mock.Mock()
            mock_process.exitcode = -1
            mp_context._pc.processes = [mock_process]
            e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
            mp_context._pc.join.side_effect = e
            with mock.patch.object(mp_context, "close"):
                run_result = mp_context._poll()
                self.assertEqual(1, len(run_result.failures))
                failure = run_result.failures[0]
                self.assertEqual(
                    "Signal 1 (SIGHUP) received by PID 123", failure.message
                )

    class StartProcessesAsBinaryTest(_StartProcessesTest):
        ########################################
        # start_processes as binary tests
        ########################################

        def test_subprocess_context_close(self):
            pc = start_processes(
                name="sleep",
                entrypoint=bin("zombie_test.py"),
                args={0: (1,)},
                envs={0: {}},
                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
            )

            pids = pc.pids()
            pc.close()
            self.assert_pids_noexist(pids)

        def test_binary_exit(self):
            FAIL = 138
            pc = start_processes(
@@ -556,45 +617,11 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
            with self.assertRaises(RuntimeError):
                _validate_full_rank({}, 10, "")

        @skip_but_pass_in_sandcastle_if(
            NO_MULTIPROCESSING_SPAWN,
            "Disabled for environments that \
                don't support multiprocessing with spawn start method",
        )
        def test_multiprocessing_context_poll_raises_exception(self):
            mp_context = MultiprocessContext(
                name="test_mp",
                entrypoint=echo0,
                args={0: (0, 1)},
                envs={0: {}},
                logs_specs=DefaultLogsSpecs(
                    log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
                ),
                start_method="spawn",
            )
            mp_context._pc = mock.Mock()
            # Using mock since we cannot just set exitcode on process
            mock_process = mock.Mock()
            mock_process.exitcode = -1
            mp_context._pc.processes = [mock_process]
            e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
            mp_context._pc.join.side_effect = e
            with mock.patch.object(mp_context, "close"):
                run_result = mp_context._poll()
                self.assertEqual(1, len(run_result.failures))
                failure = run_result.failures[0]
                self.assertEqual(
                    "Signal 1 (SIGHUP) received by PID 123", failure.message
                )


# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

    class StartProcessesListTest(StartProcessesTest):
        ########################################
        # start_processes as binary tests
        ########################################
    class StartProcessesListAsFuncTest(_StartProcessesTest):
        def test_function(self):
            for start_method, redirs in product(
                self._start_methods, redirects_oss_test()
@@ -634,6 +661,10 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
                            [f"hello stderr from {i}"], results.stderrs[i]
                        )

    class StartProcessesListAsBinaryTest(_StartProcessesTest):
        ########################################
        # start_processes as binary tests
        ########################################
        def test_binary(self):
            for redirs in redirects_oss_test():
                with self.subTest(redirs=redirs):
@@ -700,7 +731,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):

    class StartProcessesNotCITest(StartProcessesTest):
    class StartProcessesNotCIAsFuncTest(_StartProcessesTest):
        @skip_if_pytest
        def test_wrap_bad(self):
            none = ""
@@ -733,32 +764,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
                self.assert_in_file(["hello stderr from 0"], stderr_log)
                worker_finished_event_mock.wait.assert_called_once()

        def test_binary_signal(self):
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo3.py"),
                args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                logs_specs=DefaultLogsSpecs(
                    log_dir=self.log_dir(),
                ),
            )

            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))

            failure = results.failures[0]
            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
            if TEST_WITH_ASAN or TEST_WITH_TSAN:
                # ASAN/TSAN exit code is 1.
                self.assertEqual("<N/A>", failure.signal_name())
            else:
                self.assertEqual("SIGSEGV", failure.signal_name())
            self.assertEqual("<NONE>", failure.error_file_data["message"])

        def test_function_redirect_and_tee(self):
            for start_method in self._start_methods:
                with self.subTest(start_method=start_method):
@@ -871,42 +876,60 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
            self.assertTrue(pc._stderr_tail.stopped())
            self.assertTrue(pc._stdout_tail.stopped())

        def test_no_zombie_process_binary(self):
            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
            for s in signals:
                self._test_zombie_workflow(bin("zombie_test.py"), s)

        def test_no_zombie_process_function(self):
            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
            for s in signals:
                self._test_zombie_workflow(wait_fn, s)

        def _test_zombie_workflow(
            self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
        ) -> None:
            mp_queue = mp.get_context("spawn").Queue()
            child_nproc = 2
            ctx = mp.spawn(
                start_processes_zombie_test,
                nprocs=1,
                args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
                join=False,
    class StartProcessesNotCIAsBinaryTest(_StartProcessesTest):
        def test_binary_signal(self):
            pc = start_processes(
                name="echo",
                entrypoint=bin("echo3.py"),
                args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                logs_specs=DefaultLogsSpecs(
                    log_dir=self.log_dir(),
                ),
            )
            total_processes = child_nproc + 1
            pids = []
            for _ in range(total_processes):
                pids.append(mp_queue.get(timeout=120))
            parent_pid = pids[0]
            child_pids = pids[1:]

            os.kill(parent_pid, signal.SIGTERM)
            # Wait to give time for signal handlers to finish work
            time.sleep(5)
            for child_pid in child_pids:
                # Killing parent should kill all children, we expect that each call to
                # os.kill would raise OSError
                with self.assertRaises(OSError):
                    os.kill(child_pid, 0)
            results = pc.wait(period=0.1)

            self.assert_pids_noexist(pc.pids())
            self.assertTrue(results.is_failed())
            self.assertEqual(1, len(results.failures))

            failure = results.failures[0]
            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
            if TEST_WITH_ASAN or TEST_WITH_TSAN:
                # ASAN/TSAN exit code is 1.
                self.assertEqual("<N/A>", failure.signal_name())
            else:
                self.assertEqual("SIGSEGV", failure.signal_name())
            self.assertEqual("<NONE>", failure.error_file_data["message"])

        def test_no_zombie_process_binary(self):
            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
            for s in signals:
                self._test_zombie_workflow(bin("zombie_test.py"), s)

    class ForkServerTest(
        StartProcessesAsFuncTest,
        StartProcessesListAsFuncTest,
        StartProcessesNotCIAsFuncTest,
    ):
        def setUp(self):
            super().setUp()
            self._start_methods = ["forkserver"]
            self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
            os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

        def tearDown(self):
            super().tearDown()
            if self.orig_paralell_env_val is None:
                del os.environ[mp.ENV_VAR_PARALLEL_START]
            else:
                os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val


if __name__ == "__main__":
@@ -8,9 +8,14 @@ import sys
import time
import unittest

from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp

from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    TestCase,
)

def _test_success_func(i):
    pass
@@ -220,6 +225,73 @@ class ForkTest(TestCase, _TestMultiProcessing):
    start_method = 'fork'


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
    orig_paralell_env_val = None

    def setUp(self):
        super().setUp()
        self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

    def tearDown(self):
        super().tearDown()
        if self.orig_paralell_env_val is None:
            del os.environ[mp.ENV_VAR_PARALLEL_START]
        else:
            os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):

    def test_forkserver_perf(self):

        start_method = 'forkserver'
        expensive = Expensive()
        nprocs = 4
        orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)

        # test the non parallel case
        os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
        start = time.perf_counter()
        mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
        elapsed = time.perf_counter() - start
        # the elapsed time should be at least {nprocs}x the sleep time
        self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)

        # test the parallel case
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
        start = time.perf_counter()
        mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
        elapsed = time.perf_counter() - start
        # the elapsed time should be less than {nprocs}x the sleep time
        self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs)

        if orig_paralell_env_val is None:
            del os.environ[mp.ENV_VAR_PARALLEL_START]
        else:
            os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val


class Expensive:
    SLEEP_SECS = 5
    # Simulate startup overhead such as large imports
    time.sleep(SLEEP_SECS)

    def __init__(self):
        self.config: str = "*" * 1000000

    def my_call(self, *args):
        pass


class ErrorTest(TestCase):
    def test_errors_pickleable(self):
        for error in (
@@ -39,6 +39,7 @@ torch._C._multiprocessing_init()
"""Add helper function to spawn N processes and wait for completion of any of
them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import (
    ENV_VAR_PARALLEL_START,
    ProcessContext,
    ProcessExitedException,
    ProcessRaisedException,
@@ -9,13 +9,26 @@ import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]


ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"

log = logging.getLogger(__name__)

__all__ = [
    "ProcessContext",
    "ProcessException",
    "ProcessExitedException",
    "ProcessRaisedException",
    "spawn",
    "SpawnContext",
    "start_processes",
]


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]
@@ -205,12 +218,32 @@ class SpawnContext(ProcessContext):
# Currently we only add this API first, we can consider adding it to documentation as
# needed in the future.
def start_processes(
    fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this func will start processes in parallel if start_method is 'forkserver'.
    # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
    # todo: investigate why spawn does not work with threadpool and raises SIGINT
    if (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    ):
        log.info("Starting processes in parallel.")
        start_parallel = True
    else:
        # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
        start_parallel = False

    mp = multiprocessing.get_context(start_method)
    error_files = []
    processes = []
    for i in range(nprocs):
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to.  We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown).  Note: this previously
@@ -228,9 +261,21 @@ def start_processes(
            daemon=daemon,
        )
        process.start()
        error_files.append(tf.name)
        processes.append(process)
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        with ThreadPoolExecutor(max_workers=nprocs) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # idx and process rank needs to be the same.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context
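A note on the parallel branch above: as_completed yields futures in completion order, not submission order, which is why start_process returns its index and the results are written back by position; that keeps error_files[i] and processes[i] aligned with local rank i. A standalone sketch of the same pattern, with purely illustrative names (launch, handles) that are not part of this diff:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def launch(i):
        # stand-in for start_process: start the i-th worker, return rank-aligned state
        return i, f"handle-{i}"

    nprocs = 4
    handles = [None] * nprocs
    with ThreadPoolExecutor(max_workers=nprocs) as executor:
        futures = [executor.submit(launch, i) for i in range(nprocs)]
        for fut in as_completed(futures):
            idx, handle = fut.result()
            handles[idx] = handle  # completion order may differ from submission order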