Revert "Allow mp.start_processes to create processes in parallel (#133707)"

This reverts commit 3546628a2a167ace6060737eeccf8ee8fd87ddc0.

Reverted https://github.com/pytorch/pytorch/pull/133707 on behalf of https://github.com/ZainRizvi due to sorry but trunk has been consistently broken since this PR was merged. See: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10529617600/job/29191757055) [HUD commit link](3546628a2a) ([comment](https://github.com/pytorch/pytorch/pull/133707#issuecomment-2310709523))
This commit is contained in:
PyTorch MergeBot
2024-08-26 17:31:10 +00:00
parent d0ac5d55ba
commit adcce538b7
4 changed files with 144 additions and 311 deletions

View File

@ -8,14 +8,9 @@ import sys
import time
import unittest
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import (
IS_WINDOWS,
NO_MULTIPROCESSING_SPAWN,
run_tests,
TestCase,
)
def _test_success_func(i):
pass
@ -92,7 +87,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method):
# Kill self. This should take down the child processes as well.
os.kill(os.getpid(), signal.SIGTERM)
class _TestMultiProcessing(TestCase):
class _TestMultiProcessing:
start_method = None
def test_success(self):
@ -194,11 +189,10 @@ class _TestMultiProcessing(TestCase):
self.assertLess(time.time() - start, nested_child_sleep / 2)
time.sleep(0.1)
@unittest.skipIf(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that don't support the spawn start method")
class SpawnTest(_TestMultiProcessing):
class SpawnTest(TestCase, _TestMultiProcessing):
start_method = 'spawn'
def test_exception_raises(self):
@ -222,103 +216,10 @@ class SpawnTest(_TestMultiProcessing):
IS_WINDOWS,
"Fork is only available on Unix",
)
class ForkTest(_TestMultiProcessing):
class ForkTest(TestCase, _TestMultiProcessing):
start_method = 'fork'
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ForkServerTest(_TestMultiProcessing):
start_method = 'forkserver'
class _ParallelTest:
orig_paralell_env_val = None
def setUp(self):
super().setUp()
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
def tearDown(self):
super().tearDown()
if self.orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
@unittest.skipIf(
NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that don't support the spawn start method")
class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest):
pass
@unittest.skipIf(
IS_WINDOWS,
"Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
def test_forkserver_perf(self):
start_method = 'forkserver'
expensive = Expensive()
nprocs = 6
orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
# test the non parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the time should be at least 6x the sleep time
self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)
# test the parallel case
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
start = time.perf_counter()
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
elapsed = time.perf_counter() - start
# the time should be at most 1x the sleep time + small overhead
self.assertLess(elapsed, Expensive.SLEEP_SECS + 10)
if orig_paralell_env_val is None:
del os.environ[mp.ENV_VAR_PARALLEL_START]
else:
os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
class Expensive:
SLEEP_SECS = 10
# Simulate startup overhead such as large imports
time.sleep(SLEEP_SECS)
def __init__(self):
self.config: str = "*" * 1000000
def my_call(self, *args):
pass
class ErrorTest(TestCase):
def test_errors_pickleable(self):
for error in (