mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
@ -6,8 +5,13 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
|
||||
import torch.multiprocessing as mp
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
run_tests,
|
||||
IS_WINDOWS,
|
||||
NO_MULTIPROCESSING_SPAWN,
|
||||
)
|
||||
|
||||
|
||||
def test_success_func(i):
|
||||
@ -85,6 +89,7 @@ def test_nested(i, pids_queue, nested_child_sleep, start_method):
|
||||
# Kill self. This should take down the child processes as well.
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
|
||||
class _TestMultiProcessing(object):
|
||||
start_method = None
|
||||
|
||||
@ -92,7 +97,9 @@ class _TestMultiProcessing(object):
|
||||
mp.start_processes(test_success_func, nprocs=2, start_method=self.start_method)
|
||||
|
||||
def test_success_non_blocking(self):
|
||||
mp_context = mp.start_processes(test_success_func, nprocs=2, join=False, start_method=self.start_method)
|
||||
mp_context = mp.start_processes(
|
||||
test_success_func, nprocs=2, join=False, start_method=self.start_method
|
||||
)
|
||||
|
||||
# After all processes (nproc=2) have joined it must return True
|
||||
mp_context.join(timeout=None)
|
||||
@ -102,7 +109,12 @@ class _TestMultiProcessing(object):
|
||||
def test_first_argument_index(self):
|
||||
context = mp.get_context(self.start_method)
|
||||
queue = context.SimpleQueue()
|
||||
mp.start_processes(test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_success_single_arg_func,
|
||||
args=(queue,),
|
||||
nprocs=2,
|
||||
start_method=self.start_method,
|
||||
)
|
||||
self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))
|
||||
|
||||
def test_exception_single(self):
|
||||
@ -112,14 +124,21 @@ class _TestMultiProcessing(object):
|
||||
Exception,
|
||||
"\nValueError: legitimate exception from process %d$" % i,
|
||||
):
|
||||
mp.start_processes(test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_exception_single_func,
|
||||
args=(i,),
|
||||
nprocs=nprocs,
|
||||
start_method=self.start_method,
|
||||
)
|
||||
|
||||
def test_exception_all(self):
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"\nValueError: legitimate exception from process (0|1)$",
|
||||
):
|
||||
mp.start_processes(test_exception_all_func, nprocs=2, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_exception_all_func, nprocs=2, start_method=self.start_method
|
||||
)
|
||||
|
||||
def test_terminate_signal(self):
|
||||
# SIGABRT is aliased with SIGIOT
|
||||
@ -134,7 +153,9 @@ class _TestMultiProcessing(object):
|
||||
message = "process 0 terminated with exit code 22"
|
||||
|
||||
with self.assertRaisesRegex(Exception, message):
|
||||
mp.start_processes(test_terminate_signal_func, nprocs=2, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_terminate_signal_func, nprocs=2, start_method=self.start_method
|
||||
)
|
||||
|
||||
def test_terminate_exit(self):
|
||||
exitcode = 123
|
||||
@ -142,7 +163,12 @@ class _TestMultiProcessing(object):
|
||||
Exception,
|
||||
"process 0 terminated with exit code %d" % exitcode,
|
||||
):
|
||||
mp.start_processes(test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_terminate_exit_func,
|
||||
args=(exitcode,),
|
||||
nprocs=2,
|
||||
start_method=self.start_method,
|
||||
)
|
||||
|
||||
def test_success_first_then_exception(self):
|
||||
exitcode = 123
|
||||
@ -150,7 +176,12 @@ class _TestMultiProcessing(object):
|
||||
Exception,
|
||||
"ValueError: legitimate exception",
|
||||
):
|
||||
mp.start_processes(test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method)
|
||||
mp.start_processes(
|
||||
test_success_first_then_exception_func,
|
||||
args=(exitcode,),
|
||||
nprocs=2,
|
||||
start_method=self.start_method,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.platform != "linux",
|
||||
@ -187,11 +218,13 @@ class _TestMultiProcessing(object):
|
||||
self.assertLess(time.time() - start, nested_child_sleep / 2)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
NO_MULTIPROCESSING_SPAWN,
|
||||
"Disabled for environments that don't support the spawn start method")
|
||||
"Disabled for environments that don't support the spawn start method",
|
||||
)
|
||||
class SpawnTest(TestCase, _TestMultiProcessing):
|
||||
start_method = 'spawn'
|
||||
start_method = "spawn"
|
||||
|
||||
def test_exception_raises(self):
|
||||
with self.assertRaises(mp.ProcessRaisedException):
|
||||
@ -215,7 +248,8 @@ class SpawnTest(TestCase, _TestMultiProcessing):
|
||||
"Fork is only available on Unix",
|
||||
)
|
||||
class ForkTest(TestCase, _TestMultiProcessing):
|
||||
start_method = 'fork'
|
||||
start_method = "fork"
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user