mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
Revert "Allow mp.start_processes to create processes in parallel (#133707)"
This reverts commit 3546628a2a167ace6060737eeccf8ee8fd87ddc0.
Reverted https://github.com/pytorch/pytorch/pull/133707 on behalf of https://github.com/ZainRizvi due to: "Sorry, but trunk has been consistently broken since this PR was merged." See: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10529617600/job/29191757055) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/3546628a2a167ace6060737eeccf8ee8fd87ddc0) ([comment](https://github.com/pytorch/pytorch/pull/133707#issuecomment-2310709523))
This commit is contained in:
@ -226,65 +226,36 @@ def start_processes_zombie_test(
|
|||||||
pc.close(e.sigval)
|
pc.close(e.sigval)
|
||||||
|
|
||||||
|
|
||||||
class _StartProcessesTest(TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
|
|
||||||
self._start_methods = ["spawn"]
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
super().tearDown()
|
|
||||||
shutil.rmtree(self.test_dir)
|
|
||||||
|
|
||||||
def log_dir(self):
|
|
||||||
return tempfile.mkdtemp(dir=self.test_dir)
|
|
||||||
|
|
||||||
def assert_in_file(self, expected: List[str], filename: str) -> None:
|
|
||||||
expected = [f"{line.rstrip()}\n" for line in expected]
|
|
||||||
with open(filename) as fp:
|
|
||||||
actual = fp.readlines()
|
|
||||||
for line in expected:
|
|
||||||
self.assertIn(line, actual)
|
|
||||||
|
|
||||||
def assert_pids_noexist(self, pids: Dict[int, int]):
|
|
||||||
for local_rank, pid in pids.items():
|
|
||||||
with self.assertRaises(
|
|
||||||
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
|
|
||||||
):
|
|
||||||
os.kill(pid, 0)
|
|
||||||
|
|
||||||
def _test_zombie_workflow(
|
|
||||||
self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
|
|
||||||
) -> None:
|
|
||||||
mp_queue = mp.get_context("spawn").Queue()
|
|
||||||
child_nproc = 2
|
|
||||||
ctx = mp.spawn(
|
|
||||||
start_processes_zombie_test,
|
|
||||||
nprocs=1,
|
|
||||||
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
|
|
||||||
join=False,
|
|
||||||
)
|
|
||||||
total_processes = child_nproc + 1
|
|
||||||
pids = []
|
|
||||||
for _ in range(total_processes):
|
|
||||||
pids.append(mp_queue.get(timeout=120))
|
|
||||||
parent_pid = pids[0]
|
|
||||||
child_pids = pids[1:]
|
|
||||||
|
|
||||||
os.kill(parent_pid, signal.SIGTERM)
|
|
||||||
# Wait to give time for signal handlers to finish work
|
|
||||||
time.sleep(5)
|
|
||||||
for child_pid in child_pids:
|
|
||||||
# Killing the parent should kill all children, so we expect each call to
|
|
||||||
# os.kill to raise OSError
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
os.kill(child_pid, 0)
|
|
||||||
|
|
||||||
|
|
||||||
# tests incompatible with tsan or asan
|
# tests incompatible with tsan or asan
|
||||||
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||||
|
|
||||||
class StartProcessesAsFuncTest(_StartProcessesTest):
|
class StartProcessesTest(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
|
||||||
|
self._start_methods = ["spawn"]
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
shutil.rmtree(self.test_dir)
|
||||||
|
|
||||||
|
def log_dir(self):
|
||||||
|
return tempfile.mkdtemp(dir=self.test_dir)
|
||||||
|
|
||||||
|
def assert_in_file(self, expected: List[str], filename: str) -> None:
|
||||||
|
expected = [f"{line.rstrip()}\n" for line in expected]
|
||||||
|
with open(filename) as fp:
|
||||||
|
actual = fp.readlines()
|
||||||
|
for line in expected:
|
||||||
|
self.assertIn(line, actual)
|
||||||
|
|
||||||
|
def assert_pids_noexist(self, pids: Dict[int, int]):
|
||||||
|
for local_rank, pid in pids.items():
|
||||||
|
with self.assertRaises(
|
||||||
|
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
|
||||||
|
):
|
||||||
|
os.kill(pid, 0)
|
||||||
|
|
||||||
def test_to_map(self):
|
def test_to_map(self):
|
||||||
local_world_size = 2
|
local_world_size = 2
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -378,13 +349,26 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||||||
self.assertTrue(pc._stderr_tail.stopped())
|
self.assertTrue(pc._stderr_tail.stopped())
|
||||||
self.assertTrue(pc._stdout_tail.stopped())
|
self.assertTrue(pc._stdout_tail.stopped())
|
||||||
|
|
||||||
|
def test_subprocess_context_close(self):
|
||||||
|
pc = start_processes(
|
||||||
|
name="sleep",
|
||||||
|
entrypoint=bin("zombie_test.py"),
|
||||||
|
args={0: (1,)},
|
||||||
|
envs={0: {}},
|
||||||
|
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
|
||||||
|
)
|
||||||
|
|
||||||
|
pids = pc.pids()
|
||||||
|
pc.close()
|
||||||
|
self.assert_pids_noexist(pids)
|
||||||
|
|
||||||
def test_function_with_tensor(self):
|
def test_function_with_tensor(self):
|
||||||
for start_method in self._start_methods:
|
for start_method in self._start_methods:
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
name="dummy_compute",
|
name="dummy_compute",
|
||||||
entrypoint=dummy_compute,
|
entrypoint=dummy_compute,
|
||||||
args={0: ()},
|
args={},
|
||||||
envs={0: {}},
|
envs={},
|
||||||
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
|
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
)
|
)
|
||||||
@ -505,55 +489,10 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||||||
mpc._poll()
|
mpc._poll()
|
||||||
self.assertEqual(4, mock_join.call_count)
|
self.assertEqual(4, mock_join.call_count)
|
||||||
|
|
||||||
@skip_but_pass_in_sandcastle_if(
|
|
||||||
NO_MULTIPROCESSING_SPAWN,
|
|
||||||
"Disabled for environments that \
|
|
||||||
don't support multiprocessing with spawn start method",
|
|
||||||
)
|
|
||||||
def test_multiprocessing_context_poll_raises_exception(self):
|
|
||||||
mp_context = MultiprocessContext(
|
|
||||||
name="test_mp",
|
|
||||||
entrypoint=echo0,
|
|
||||||
args={0: (0, 1)},
|
|
||||||
envs={0: {}},
|
|
||||||
logs_specs=DefaultLogsSpecs(
|
|
||||||
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
|
|
||||||
),
|
|
||||||
start_method="spawn",
|
|
||||||
)
|
|
||||||
mp_context._pc = mock.Mock()
|
|
||||||
# Using mock since we cannot just set exitcode on process
|
|
||||||
mock_process = mock.Mock()
|
|
||||||
mock_process.exitcode = -1
|
|
||||||
mp_context._pc.processes = [mock_process]
|
|
||||||
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
|
|
||||||
mp_context._pc.join.side_effect = e
|
|
||||||
with mock.patch.object(mp_context, "close"):
|
|
||||||
run_result = mp_context._poll()
|
|
||||||
self.assertEqual(1, len(run_result.failures))
|
|
||||||
failure = run_result.failures[0]
|
|
||||||
self.assertEqual(
|
|
||||||
"Signal 1 (SIGHUP) received by PID 123", failure.message
|
|
||||||
)
|
|
||||||
|
|
||||||
class StartProcessesAsBinaryTest(_StartProcessesTest):
|
|
||||||
########################################
|
########################################
|
||||||
# start_processes as binary tests
|
# start_processes as binary tests
|
||||||
########################################
|
########################################
|
||||||
|
|
||||||
def test_subprocess_context_close(self):
|
|
||||||
pc = start_processes(
|
|
||||||
name="sleep",
|
|
||||||
entrypoint=bin("zombie_test.py"),
|
|
||||||
args={0: (1,)},
|
|
||||||
envs={0: {}},
|
|
||||||
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
|
|
||||||
)
|
|
||||||
|
|
||||||
pids = pc.pids()
|
|
||||||
pc.close()
|
|
||||||
self.assert_pids_noexist(pids)
|
|
||||||
|
|
||||||
def test_binary_exit(self):
|
def test_binary_exit(self):
|
||||||
FAIL = 138
|
FAIL = 138
|
||||||
pc = start_processes(
|
pc = start_processes(
|
||||||
@ -617,11 +556,45 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_validate_full_rank({}, 10, "")
|
_validate_full_rank({}, 10, "")
|
||||||
|
|
||||||
|
@skip_but_pass_in_sandcastle_if(
|
||||||
|
NO_MULTIPROCESSING_SPAWN,
|
||||||
|
"Disabled for environments that \
|
||||||
|
don't support multiprocessing with spawn start method",
|
||||||
|
)
|
||||||
|
def test_multiprocessing_context_poll_raises_exception(self):
|
||||||
|
mp_context = MultiprocessContext(
|
||||||
|
name="test_mp",
|
||||||
|
entrypoint=echo0,
|
||||||
|
args={0: (0, 1)},
|
||||||
|
envs={0: {}},
|
||||||
|
logs_specs=DefaultLogsSpecs(
|
||||||
|
log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
|
||||||
|
),
|
||||||
|
start_method="spawn",
|
||||||
|
)
|
||||||
|
mp_context._pc = mock.Mock()
|
||||||
|
# Using mock since we cannot just set exitcode on process
|
||||||
|
mock_process = mock.Mock()
|
||||||
|
mock_process.exitcode = -1
|
||||||
|
mp_context._pc.processes = [mock_process]
|
||||||
|
e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
|
||||||
|
mp_context._pc.join.side_effect = e
|
||||||
|
with mock.patch.object(mp_context, "close"):
|
||||||
|
run_result = mp_context._poll()
|
||||||
|
self.assertEqual(1, len(run_result.failures))
|
||||||
|
failure = run_result.failures[0]
|
||||||
|
self.assertEqual(
|
||||||
|
"Signal 1 (SIGHUP) received by PID 123", failure.message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# tests incompatible with tsan or asan; the redirect functionality does not work on macOS or Windows
|
# tests incompatible with tsan or asan; the redirect functionality does not work on macOS or Windows
|
||||||
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||||
|
|
||||||
class StartProcessesListAsFuncTest(_StartProcessesTest):
|
class StartProcessesListTest(StartProcessesTest):
|
||||||
|
########################################
|
||||||
|
# start_processes as binary tests
|
||||||
|
########################################
|
||||||
def test_function(self):
|
def test_function(self):
|
||||||
for start_method, redirs in product(
|
for start_method, redirs in product(
|
||||||
self._start_methods, redirects_oss_test()
|
self._start_methods, redirects_oss_test()
|
||||||
@ -661,10 +634,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||||||
[f"hello stderr from {i}"], results.stderrs[i]
|
[f"hello stderr from {i}"], results.stderrs[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
class StartProcessesListAsBinaryTest(_StartProcessesTest):
|
|
||||||
########################################
|
|
||||||
# start_processes as binary tests
|
|
||||||
########################################
|
|
||||||
def test_binary(self):
|
def test_binary(self):
|
||||||
for redirs in redirects_oss_test():
|
for redirs in redirects_oss_test():
|
||||||
with self.subTest(redirs=redirs):
|
with self.subTest(redirs=redirs):
|
||||||
@ -731,7 +700,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||||||
# tests incompatible with tsan or asan; the redirect functionality does not work on macOS or Windows
|
# tests incompatible with tsan or asan; the redirect functionality does not work on macOS or Windows
|
||||||
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
||||||
|
|
||||||
class StartProcessesNotCIAsFuncTest(_StartProcessesTest):
|
class StartProcessesNotCITest(StartProcessesTest):
|
||||||
@skip_if_pytest
|
@skip_if_pytest
|
||||||
def test_wrap_bad(self):
|
def test_wrap_bad(self):
|
||||||
none = ""
|
none = ""
|
||||||
@ -764,6 +733,32 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
|||||||
self.assert_in_file(["hello stderr from 0"], stderr_log)
|
self.assert_in_file(["hello stderr from 0"], stderr_log)
|
||||||
worker_finished_event_mock.wait.assert_called_once()
|
worker_finished_event_mock.wait.assert_called_once()
|
||||||
|
|
||||||
|
def test_binary_signal(self):
|
||||||
|
pc = start_processes(
|
||||||
|
name="echo",
|
||||||
|
entrypoint=bin("echo3.py"),
|
||||||
|
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
|
||||||
|
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||||
|
logs_specs=DefaultLogsSpecs(
|
||||||
|
log_dir=self.log_dir(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
results = pc.wait(period=0.1)
|
||||||
|
|
||||||
|
self.assert_pids_noexist(pc.pids())
|
||||||
|
self.assertTrue(results.is_failed())
|
||||||
|
self.assertEqual(1, len(results.failures))
|
||||||
|
|
||||||
|
failure = results.failures[0]
|
||||||
|
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
|
||||||
|
if TEST_WITH_ASAN or TEST_WITH_TSAN:
|
||||||
|
# ASAN/TSAN exit code is 1.
|
||||||
|
self.assertEqual("<N/A>", failure.signal_name())
|
||||||
|
else:
|
||||||
|
self.assertEqual("SIGSEGV", failure.signal_name())
|
||||||
|
self.assertEqual("<NONE>", failure.error_file_data["message"])
|
||||||
|
|
||||||
def test_function_redirect_and_tee(self):
|
def test_function_redirect_and_tee(self):
|
||||||
for start_method in self._start_methods:
|
for start_method in self._start_methods:
|
||||||
with self.subTest(start_method=start_method):
|
with self.subTest(start_method=start_method):
|
||||||
@ -876,60 +871,42 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
|
|||||||
self.assertTrue(pc._stderr_tail.stopped())
|
self.assertTrue(pc._stderr_tail.stopped())
|
||||||
self.assertTrue(pc._stdout_tail.stopped())
|
self.assertTrue(pc._stdout_tail.stopped())
|
||||||
|
|
||||||
def test_no_zombie_process_function(self):
|
|
||||||
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
|
||||||
for s in signals:
|
|
||||||
self._test_zombie_workflow(wait_fn, s)
|
|
||||||
|
|
||||||
class StartProcessesNotCIAsBinaryTest(_StartProcessesTest):
|
|
||||||
def test_binary_signal(self):
|
|
||||||
pc = start_processes(
|
|
||||||
name="echo",
|
|
||||||
entrypoint=bin("echo3.py"),
|
|
||||||
args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
|
|
||||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
|
||||||
logs_specs=DefaultLogsSpecs(
|
|
||||||
log_dir=self.log_dir(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
results = pc.wait(period=0.1)
|
|
||||||
|
|
||||||
self.assert_pids_noexist(pc.pids())
|
|
||||||
self.assertTrue(results.is_failed())
|
|
||||||
self.assertEqual(1, len(results.failures))
|
|
||||||
|
|
||||||
failure = results.failures[0]
|
|
||||||
self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
|
|
||||||
if TEST_WITH_ASAN or TEST_WITH_TSAN:
|
|
||||||
# ASAN/TSAN exit code is 1.
|
|
||||||
self.assertEqual("<N/A>", failure.signal_name())
|
|
||||||
else:
|
|
||||||
self.assertEqual("SIGSEGV", failure.signal_name())
|
|
||||||
self.assertEqual("<NONE>", failure.error_file_data["message"])
|
|
||||||
|
|
||||||
def test_no_zombie_process_binary(self):
|
def test_no_zombie_process_binary(self):
|
||||||
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
||||||
for s in signals:
|
for s in signals:
|
||||||
self._test_zombie_workflow(bin("zombie_test.py"), s)
|
self._test_zombie_workflow(bin("zombie_test.py"), s)
|
||||||
|
|
||||||
class ForkServerTest(
|
def test_no_zombie_process_function(self):
|
||||||
StartProcessesAsFuncTest,
|
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
|
||||||
StartProcessesListAsFuncTest,
|
for s in signals:
|
||||||
StartProcessesNotCIAsFuncTest,
|
self._test_zombie_workflow(wait_fn, s)
|
||||||
):
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
self._start_methods = ["forkserver"]
|
|
||||||
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
|
|
||||||
|
|
||||||
def tearDown(self):
|
def _test_zombie_workflow(
|
||||||
super().tearDown()
|
self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
|
||||||
if self.orig_paralell_env_val is None:
|
) -> None:
|
||||||
del os.environ[mp.ENV_VAR_PARALLEL_START]
|
mp_queue = mp.get_context("spawn").Queue()
|
||||||
else:
|
child_nproc = 2
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
|
ctx = mp.spawn(
|
||||||
|
start_processes_zombie_test,
|
||||||
|
nprocs=1,
|
||||||
|
args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
|
||||||
|
join=False,
|
||||||
|
)
|
||||||
|
total_processes = child_nproc + 1
|
||||||
|
pids = []
|
||||||
|
for _ in range(total_processes):
|
||||||
|
pids.append(mp_queue.get(timeout=120))
|
||||||
|
parent_pid = pids[0]
|
||||||
|
child_pids = pids[1:]
|
||||||
|
|
||||||
|
os.kill(parent_pid, signal.SIGTERM)
|
||||||
|
# Wait to give time for signal handlers to finish work
|
||||||
|
time.sleep(5)
|
||||||
|
for child_pid in child_pids:
|
||||||
|
# Killing the parent should kill all children, so we expect each call to
|
||||||
|
# os.kill to raise OSError
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
os.kill(child_pid, 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -8,14 +8,9 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import (
|
|
||||||
IS_WINDOWS,
|
|
||||||
NO_MULTIPROCESSING_SPAWN,
|
|
||||||
run_tests,
|
|
||||||
TestCase,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _test_success_func(i):
|
def _test_success_func(i):
|
||||||
pass
|
pass
|
||||||
@ -92,7 +87,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method):
|
|||||||
# Kill self. This should take down the child processes as well.
|
# Kill self. This should take down the child processes as well.
|
||||||
os.kill(os.getpid(), signal.SIGTERM)
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
|
||||||
class _TestMultiProcessing(TestCase):
|
class _TestMultiProcessing:
|
||||||
start_method = None
|
start_method = None
|
||||||
|
|
||||||
def test_success(self):
|
def test_success(self):
|
||||||
@ -194,11 +189,10 @@ class _TestMultiProcessing(TestCase):
|
|||||||
self.assertLess(time.time() - start, nested_child_sleep / 2)
|
self.assertLess(time.time() - start, nested_child_sleep / 2)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
NO_MULTIPROCESSING_SPAWN,
|
NO_MULTIPROCESSING_SPAWN,
|
||||||
"Disabled for environments that don't support the spawn start method")
|
"Disabled for environments that don't support the spawn start method")
|
||||||
class SpawnTest(_TestMultiProcessing):
|
class SpawnTest(TestCase, _TestMultiProcessing):
|
||||||
start_method = 'spawn'
|
start_method = 'spawn'
|
||||||
|
|
||||||
def test_exception_raises(self):
|
def test_exception_raises(self):
|
||||||
@ -222,103 +216,10 @@ class SpawnTest(_TestMultiProcessing):
|
|||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
"Fork is only available on Unix",
|
"Fork is only available on Unix",
|
||||||
)
|
)
|
||||||
class ForkTest(_TestMultiProcessing):
|
class ForkTest(TestCase, _TestMultiProcessing):
|
||||||
start_method = 'fork'
|
start_method = 'fork'
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
IS_WINDOWS,
|
|
||||||
"Fork is only available on Unix",
|
|
||||||
)
|
|
||||||
class ForkServerTest(_TestMultiProcessing):
|
|
||||||
start_method = 'forkserver'
|
|
||||||
|
|
||||||
|
|
||||||
class _ParallelTest:
|
|
||||||
orig_paralell_env_val = None
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
super().tearDown()
|
|
||||||
if self.orig_paralell_env_val is None:
|
|
||||||
del os.environ[mp.ENV_VAR_PARALLEL_START]
|
|
||||||
else:
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
NO_MULTIPROCESSING_SPAWN,
|
|
||||||
"Disabled for environments that don't support the spawn start method")
|
|
||||||
class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
IS_WINDOWS,
|
|
||||||
"Fork is only available on Unix",
|
|
||||||
)
|
|
||||||
class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
IS_WINDOWS,
|
|
||||||
"Fork is only available on Unix",
|
|
||||||
)
|
|
||||||
class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
IS_WINDOWS,
|
|
||||||
"Fork is only available on Unix",
|
|
||||||
)
|
|
||||||
class ParallelForkServerPerfTest(TestCase):
|
|
||||||
def test_forkserver_perf(self):
|
|
||||||
start_method = 'forkserver'
|
|
||||||
expensive = Expensive()
|
|
||||||
nprocs = 6
|
|
||||||
orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
|
|
||||||
|
|
||||||
# test the non parallel case
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
|
|
||||||
start = time.perf_counter()
|
|
||||||
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
|
|
||||||
elapsed = time.perf_counter() - start
|
|
||||||
# the time should be at least 6x the sleep time
|
|
||||||
self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)
|
|
||||||
|
|
||||||
# test the parallel case
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
|
|
||||||
start = time.perf_counter()
|
|
||||||
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
|
|
||||||
elapsed = time.perf_counter() - start
|
|
||||||
|
|
||||||
# the time should be at most 1x the sleep time + small overhead
|
|
||||||
self.assertLess(elapsed, Expensive.SLEEP_SECS + 10)
|
|
||||||
|
|
||||||
if orig_paralell_env_val is None:
|
|
||||||
del os.environ[mp.ENV_VAR_PARALLEL_START]
|
|
||||||
else:
|
|
||||||
os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
|
|
||||||
|
|
||||||
|
|
||||||
class Expensive:
|
|
||||||
SLEEP_SECS = 10
|
|
||||||
# Simulate startup overhead such as large imports
|
|
||||||
time.sleep(SLEEP_SECS)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.config: str = "*" * 1000000
|
|
||||||
|
|
||||||
def my_call(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorTest(TestCase):
|
class ErrorTest(TestCase):
|
||||||
def test_errors_pickleable(self):
|
def test_errors_pickleable(self):
|
||||||
for error in (
|
for error in (
|
||||||
|
|||||||
@ -39,7 +39,6 @@ torch._C._multiprocessing_init()
|
|||||||
"""Add helper function to spawn N processes and wait for completion of any of
|
"""Add helper function to spawn N processes and wait for completion of any of
|
||||||
them. This depends on `mp.get_context` which was added in Python 3.4."""
|
them. This depends on `mp.get_context` which was added in Python 3.4."""
|
||||||
from .spawn import (
|
from .spawn import (
|
||||||
ENV_VAR_PARALLEL_START,
|
|
||||||
ProcessContext,
|
ProcessContext,
|
||||||
ProcessExitedException,
|
ProcessExitedException,
|
||||||
ProcessRaisedException,
|
ProcessRaisedException,
|
||||||
|
|||||||
@ -9,26 +9,13 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import as_completed, ThreadPoolExecutor
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
|
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ProcessContext",
|
|
||||||
"ProcessException",
|
|
||||||
"ProcessExitedException",
|
|
||||||
"ProcessRaisedException",
|
|
||||||
"spawn",
|
|
||||||
"SpawnContext",
|
|
||||||
"start_processes",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessException(Exception):
|
class ProcessException(Exception):
|
||||||
__slots__ = ["error_index", "error_pid"]
|
__slots__ = ["error_index", "error_pid"]
|
||||||
@ -218,31 +205,12 @@ class SpawnContext(ProcessContext):
|
|||||||
# Currently we only add this API first, we can consider adding it to documentation as
|
# Currently we only add this API first, we can consider adding it to documentation as
|
||||||
# needed in the future.
|
# needed in the future.
|
||||||
def start_processes(
|
def start_processes(
|
||||||
fn,
|
fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
|
||||||
args=(),
|
|
||||||
nprocs=1,
|
|
||||||
join=True,
|
|
||||||
daemon=False,
|
|
||||||
start_method="spawn",
|
|
||||||
):
|
):
|
||||||
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
|
|
||||||
# this func will start processes in parallel if start_method is 'forkserver'.
|
|
||||||
# Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
|
|
||||||
# todo: investigate why spawn does not work with threadpool and raises SIGINT
|
|
||||||
if (
|
|
||||||
start_method == "forkserver"
|
|
||||||
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
|
|
||||||
):
|
|
||||||
start_parallel = True
|
|
||||||
else:
|
|
||||||
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
|
|
||||||
start_parallel = False
|
|
||||||
|
|
||||||
mp = multiprocessing.get_context(start_method)
|
mp = multiprocessing.get_context(start_method)
|
||||||
error_files = [None] * nprocs
|
error_files = []
|
||||||
processes = [None] * nprocs
|
processes = []
|
||||||
|
for i in range(nprocs):
|
||||||
def start_process(i):
|
|
||||||
# Each process is assigned a file to write tracebacks to. We
|
# Each process is assigned a file to write tracebacks to. We
|
||||||
# use the file being non-empty to indicate an exception
|
# use the file being non-empty to indicate an exception
|
||||||
# occurred (vs an expected shutdown). Note: this previously
|
# occurred (vs an expected shutdown). Note: this previously
|
||||||
@ -260,21 +228,9 @@ def start_processes(
|
|||||||
daemon=daemon,
|
daemon=daemon,
|
||||||
)
|
)
|
||||||
process.start()
|
process.start()
|
||||||
return i, process, tf.name
|
error_files.append(tf.name)
|
||||||
|
processes.append(process)
|
||||||
|
|
||||||
if not start_parallel:
|
|
||||||
for i in range(nprocs):
|
|
||||||
idx, process, tf_name = start_process(i)
|
|
||||||
error_files[idx] = tf_name
|
|
||||||
processes[idx] = process
|
|
||||||
else:
|
|
||||||
with ThreadPoolExecutor(max_workers=nprocs) as executor:
|
|
||||||
futures = [executor.submit(start_process, i) for i in range(nprocs)]
|
|
||||||
for fut in as_completed(futures):
|
|
||||||
idx, process, tf_name = fut.result()
|
|
||||||
# idx and process rank need to be the same.
|
|
||||||
error_files[idx] = tf_name
|
|
||||||
processes[idx] = process
|
|
||||||
context = ProcessContext(processes, error_files)
|
context = ProcessContext(processes, error_files)
|
||||||
if not join:
|
if not join:
|
||||||
return context
|
return context
|
||||||
|
|||||||
Reference in New Issue
Block a user