mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Allow mp.start_processes to create processes in parallel (#133707)"
This reverts commit 3546628a2a167ace6060737eeccf8ee8fd87ddc0.
Reverted https://github.com/pytorch/pytorch/pull/133707 on behalf of https://github.com/ZainRizvi due to sorry but trunk has been consistently broken since this PR was merged. See: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10529617600/job/29191757055) [HUD commit link](3546628a2a
) ([comment](https://github.com/pytorch/pytorch/pull/133707#issuecomment-2310709523))
This commit is contained in:
@ -8,14 +8,9 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
NO_MULTIPROCESSING_SPAWN,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
def _test_success_func(i):
|
||||
pass
|
||||
@ -92,7 +87,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method):
|
||||
# Kill self. This should take down the child processes as well.
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
class _TestMultiProcessing(TestCase):
|
||||
class _TestMultiProcessing:
|
||||
start_method = None
|
||||
|
||||
def test_success(self):
|
||||
@ -194,11 +189,10 @@ class _TestMultiProcessing(TestCase):
|
||||
self.assertLess(time.time() - start, nested_child_sleep / 2)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
NO_MULTIPROCESSING_SPAWN,
|
||||
"Disabled for environments that don't support the spawn start method")
|
||||
class SpawnTest(_TestMultiProcessing):
|
||||
class SpawnTest(TestCase, _TestMultiProcessing):
|
||||
start_method = 'spawn'
|
||||
|
||||
def test_exception_raises(self):
|
||||
@ -222,103 +216,10 @@ class SpawnTest(_TestMultiProcessing):
|
||||
IS_WINDOWS,
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ForkTest(_TestMultiProcessing):
|
||||
class ForkTest(TestCase, _TestMultiProcessing):
|
||||
start_method = 'fork'
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_WINDOWS,
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ForkServerTest(_TestMultiProcessing):
|
||||
start_method = 'forkserver'
|
||||
|
||||
|
||||
class _ParallelTest:
|
||||
orig_paralell_env_val = None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
|
||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
if self.orig_paralell_env_val is None:
|
||||
del os.environ[mp.ENV_VAR_PARALLEL_START]
|
||||
else:
|
||||
os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
NO_MULTIPROCESSING_SPAWN,
|
||||
"Disabled for environments that don't support the spawn start method")
|
||||
class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_WINDOWS,
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_WINDOWS,
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_WINDOWS,
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ParallelForkServerPerfTest(TestCase):
|
||||
def test_forkserver_perf(self):
|
||||
start_method = 'forkserver'
|
||||
expensive = Expensive()
|
||||
nprocs = 6
|
||||
orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
|
||||
|
||||
# test the non parallel case
|
||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
|
||||
start = time.perf_counter()
|
||||
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
|
||||
elapsed = time.perf_counter() - start
|
||||
# the time should be at least 6x the sleep time
|
||||
self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)
|
||||
|
||||
# test the parallel case
|
||||
os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
|
||||
start = time.perf_counter()
|
||||
mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
# the time should be at most 1x the sleep time + small overhead
|
||||
self.assertLess(elapsed, Expensive.SLEEP_SECS + 10)
|
||||
|
||||
if orig_paralell_env_val is None:
|
||||
del os.environ[mp.ENV_VAR_PARALLEL_START]
|
||||
else:
|
||||
os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val
|
||||
|
||||
|
||||
class Expensive:
|
||||
SLEEP_SECS = 10
|
||||
# Simulate startup overhead such as large imports
|
||||
time.sleep(SLEEP_SECS)
|
||||
|
||||
def __init__(self):
|
||||
self.config: str = "*" * 1000000
|
||||
|
||||
def my_call(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class ErrorTest(TestCase):
|
||||
def test_errors_pickleable(self):
|
||||
for error in (
|
||||
|
Reference in New Issue
Block a user