[8/N] Remove c10d/ddp fork tests. (#63454)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63454

Continuation of https://github.com/pytorch/pytorch/pull/63443, this
PR removes all fork tests from torch.distributed.
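For background, fork-based multiprocess tests start each worker by fork()-ing the already multi-threaded test process, which TSAN/ASAN flag and which is unsafe once CUDA or thread pools are initialized; the spawn start method instead launches a fresh interpreter per worker. A minimal sketch of that difference using only the standard library (the worker function and rank count are illustrative, not taken from this PR):

import multiprocessing as mp

def worker(rank):
    # With the spawn start method each child starts a fresh Python
    # interpreter and re-imports this module, so the target must be a
    # module-level function and its arguments must be picklable.
    print(f"hello from rank {rank}")

if __name__ == "__main__":
    # fork() copies the (possibly multi-threaded) parent process; spawn
    # avoids that by starting each worker from scratch.
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=worker, args=(rank,)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0

The test changes below follow the same idea: MultiProcessTestCase subclasses call self._spawn_processes() instead of self._fork_processes(), and the launcher/elastic tests pass start_method="spawn".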
ghstack-source-id: 136285511

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30387872

fbshipit-source-id: f6d6313db126ae7b95b86f78a1e0726887c5c513
Author: Pritam Damania
Committer: Facebook GitHub Bot
Date: 2021-08-20 12:09:49 -07:00
Parent: 71da114412
Commit: 2d671ca41b
23 changed files with 348 additions and 533 deletions

View File

@@ -19,7 +19,6 @@ fi
 python tools/download_mnist.py --quiet -d test/cpp/api/mnist
 OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api
 time python test/run_test.py --verbose -i distributed/test_jit_c10d
-time python test/run_test.py --verbose -i distributed/test_distributed_fork
 time python test/run_test.py --verbose -i distributed/test_c10d_common
 time python test/run_test.py --verbose -i distributed/test_c10d_gloo
 time python test/run_test.py --verbose -i distributed/test_c10d_nccl

View File

@@ -21,8 +21,14 @@ from torch.testing._internal.common_distributed import (
     requires_nccl,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import (
+    run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
+    sys.exit(0)


 def gpus_for_rank(world_size):
     visible_devices = list(range(torch.cuda.device_count()))
@@ -57,7 +63,7 @@ class TestDdpCommHook(nn.Module):
 class DistributedDataParallelCommHookTest(MultiProcessTestCase):
     def setUp(self):
         super(DistributedDataParallelCommHookTest, self).setUp()
-        self._fork_processes()
+        self._spawn_processes()

     def tearDown(self):
         try:

View File

@@ -37,7 +37,6 @@ from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.rpc.backend_registry import BackendType
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
-    TEST_WITH_TSAN,
     sandcastle_skip_if,
 )

@@ -406,19 +405,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual((100, 100), return_value.shape)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_dummy_compute_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_dummy_compute_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_dummy_compute_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute)
@@ -431,19 +430,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertIsNone(res.return_values[1])

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_happy_function_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_happy_function_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_happy_function_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_happy_function)
@@ -465,13 +464,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertIsNone(res.return_values[0])

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_check_master_addr_port_override_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.check_master_addr_port_override)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_check_master_addr_port_override_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.check_master_addr_port_override)
@@ -484,7 +483,7 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertFalse(res.is_failed())

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_check_env_function_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_check_env_function)
@@ -497,19 +496,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual("foo", res.return_values[1])

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_function_with_return_value_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_function_with_return_value)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_function_with_return_value_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_function_with_return_value)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_function_with_return_value_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_function_with_return_value)
@@ -520,19 +519,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         # _dist_sum internally checks that the sum computed is valid

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_simple_dist_sum_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_simple_dist_sum_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_simple_dist_sum_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum)
@@ -556,19 +555,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertSetEqual(set(range(4 + 4)), ranks)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_homogeneous_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_homogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_homogeneous_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_homogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_homogeneous_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous)
@@ -596,19 +595,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertSetEqual(set(range(1 + 2 + 3)), ranks)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_heterogeneous_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_heterogeneous_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_distributed_sum_heterogeneous_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous)
@@ -636,19 +635,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_sad_function_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_sad_function_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_sad_function_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function)
@@ -668,19 +667,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertTrue(agent._total_execution_time > 0)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_bipolar_function_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.run_bipolar_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_bipolar_function_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.run_bipolar_function)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_run_bipolar_function_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_bipolar_function)
@@ -711,13 +710,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         )

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_correct_rank_assignment_heterogeneous_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_correct_rank_assignment_heterogeneous_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous)
@@ -744,13 +743,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         )

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_correct_rank_assignment_homogeneous_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_correct_rank_assignment_homogeneous_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous)
@@ -852,13 +851,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual(0, p.exitcode)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_double_agent_fault_tolerance_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_fault_tolerance)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_double_agent_fault_tolerance_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance)
@@ -905,19 +904,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual(-signal.SIGKILL, p.exitcode)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_double_agent_elastic_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.double_agent_elastic)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_double_agent_elastic_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_elastic)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_double_agent_elastic_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_elastic)
@@ -955,19 +954,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual([f"{msg} from worker"], list(master_retvals.values()))

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_torch_rpc_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_torch_rpc_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_torch_rpc_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc)
@@ -993,13 +992,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual(rank, output)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_workers_drift_success_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_success)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_workers_drift_success_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_success)
@@ -1024,13 +1023,13 @@ class LocalElasticAgentTest(unittest.TestCase):
         self.assertEqual(rank, output)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_workers_drift_fail_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_workers_drift_fail_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_fail)
@@ -1047,19 +1046,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         barrier_mock.assert_called_once()

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_barrier_failed_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_barrier_failed_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
     def test_barrier_failed_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed)
@@ -1081,19 +1080,19 @@ class LocalElasticAgentTest(unittest.TestCase):
         pcontext_mock.close.assert_called_once()

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_shutdown_called_c10d(self):
         self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_shutdown_called_etcd(self):
         self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_shutdown_called_etcd_v2(self):
         self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called)

View File

@@ -35,8 +35,8 @@ from torch.distributed.elastic.multiprocessing.errors.error_handler import _writ
 from torch.testing._internal.common_utils import (
     NO_MULTIPROCESSING_SPAWN,
     TEST_WITH_ASAN,
-    TEST_WITH_DEV_DBG_ASAN,
     TEST_WITH_TSAN,
+    TEST_WITH_DEV_DBG_ASAN,
     IS_IN_CI,
     IS_WINDOWS,
     IS_MACOS,
@@ -223,15 +223,11 @@ def start_processes_zombie_test(

 # tests incompatible with tsan or asan
-if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
+if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

     class StartProcessesTest(unittest.TestCase):
         def setUp(self):
             self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
-
-            if NO_MULTIPROCESSING_SPAWN:  # python 2.7 doesn't have spawn
-                self._start_methods = ["fork"]
-            else:
-                self._start_methods = ["fork", "spawn"]
+            self._start_methods = ["spawn"]

         def tearDown(self):
             shutil.rmtree(self.test_dir)
@@ -317,7 +313,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
                 args={0: (1,)},
                 envs={0: {}},
                 log_dir=self.log_dir(),
-                start_method="fork",
+                start_method="spawn",
             )

             self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
@@ -332,7 +328,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
                 args={0: (1,)},
                 envs={0: {}},
                 log_dir=self.log_dir(),
-                start_method="fork",
+                start_method="spawn",
             )

             pids = pc.pids()
@@ -387,7 +383,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
             self.assertEqual({0: None, 1: None}, results.return_values)

         @sandcastle_skip_if(
-            TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
+            TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan"
         )
         def test_function_large_ret_val(self):
             # python multiprocessing.queue module uses pipes and actually PipedQueues
@@ -549,7 +545,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):

 # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
-if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
+if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

     class StartProcessesListTest(StartProcessesTest):
         ########################################
         # start_processes as binary tests
@@ -630,7 +626,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
                 args={0: ("hello",), 1: ("world",)},
                 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                 log_dir=self.log_dir(),
-                start_method="fork",
+                start_method="spawn",
                 redirects={0: Std.ERR, 1: Std.NONE},
                 tee={0: Std.OUT, 1: Std.ERR},
             )
@@ -647,7 +643,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):

 # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
-if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):
+if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):

     class StartProcessesNotCITest(StartProcessesTest):
         def test_wrap_bad(self):
             none = ""
@@ -697,8 +693,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS

             failure = results.failures[0]
             self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
-            if TEST_WITH_ASAN:
-                # ASAN exit code is 1.
+            if TEST_WITH_ASAN or TEST_WITH_TSAN:
+                # ASAN/TSAN exit code is 1.
                 self.assertEqual("<N/A>", failure.signal_name())
             else:
                 self.assertEqual("SIGSEGV", failure.signal_name())
@@ -714,7 +710,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS
                 args={0: ("hello",), 1: ("world",)},
                 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                 log_dir=log_dir,
-                start_method="fork",
+                start_method="spawn",
                 redirects={0: Std.ERR, 1: Std.NONE},
                 tee={0: Std.OUT, 1: Std.ERR},
             )

View File

@@ -13,7 +13,6 @@ from torch.distributed.elastic.multiprocessing.errors import (
     record,
 )
 from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
-from torch.testing._internal.common_utils import TEST_WITH_TSAN


 class SentinelError(Exception):
@@ -45,10 +44,6 @@ def read_resource_file(resource_file: str) -> str:
         return "".join(fp.readlines())


-if TEST_WITH_TSAN:
-    print("test incompatible with tsan", file=sys.stderr)
-    sys.exit(0)
-
 class ApiTest(unittest.TestCase):
     def setUp(self):
         self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)

View File

@@ -15,7 +15,6 @@ import torch.distributed.elastic.timer as timer
 import torch.multiprocessing as torch_mp
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
-    TEST_WITH_TSAN,
     run_tests,
     IS_WINDOWS,
     IS_MACOS,
@@ -55,7 +54,7 @@ if not (IS_WINDOWS or IS_MACOS):
         unittest. As of now this will SIGSEGV.
         """

-        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
+        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
         def test_torch_mp_example(self):
             # in practice set the max_interval to a larger value (e.g. 60 seconds)
             mp_queue = mp.get_context("spawn").Queue()
@@ -80,18 +79,14 @@ if not (IS_WINDOWS or IS_MACOS):

             server.stop()

-        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
+        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
         def test_example_start_method_spawn(self):
             self._run_example_with(start_method="spawn")

-        # @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
+        # @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
         # def test_example_start_method_forkserver(self):
         #     self._run_example_with(start_method="forkserver")

-        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
-        def test_example_start_method_fork(self):
-            self._run_example_with(start_method="fork")
-
         def _run_example_with(self, start_method):
             spawn_ctx = mp.get_context(start_method)
             mp_queue = spawn_ctx.Queue()

View File

@@ -13,19 +13,28 @@ import torch.distributed.elastic.timer as timer
 from torch.distributed.elastic.timer.api import TimerRequest
 from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
 from torch.testing._internal.common_utils import (
-    TEST_WITH_TSAN,
     run_tests,
     IS_WINDOWS,
     IS_MACOS,
-    sandcastle_skip_if,
+    TEST_WITH_DEV_DBG_ASAN,
 )

 # timer is not supported on windows or macos
-if not (IS_WINDOWS or IS_MACOS):
+if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
+
+    # func2 should time out
+    def func2(n, mp_queue):
+        if mp_queue is not None:
+            timer.configure(timer.LocalTimerClient(mp_queue))
+        if n > 0:
+            with timer.expires(after=0.1):
+                func2(n - 1, None)
+                time.sleep(0.2)
+
     class LocalTimerTest(unittest.TestCase):
         def setUp(self):
-            self.mp_queue = mp.Queue()
+            self.ctx = mp.get_context('spawn')
+            self.mp_queue = self.ctx.Queue()
             self.max_interval = 0.01
             self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
             self.server.start()
@@ -62,7 +71,6 @@ if not (IS_WINDOWS or IS_MACOS):
             with timer.expires(after=0.5):
                 time.sleep(0.1)

-        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
         def test_get_timer_recursive(self):
             """
             If a function acquires a countdown timer with default scope,
@@ -82,14 +90,7 @@ if not (IS_WINDOWS or IS_MACOS):

             func(4)

-            # func2 should time out
-            def func2(n):
-                if n > 0:
-                    with timer.expires(after=0.1):
-                        func2(n - 1)
-                        time.sleep(0.2)
-
-            p = mp.Process(target=func2, args=(2,))
+            p = self.ctx.Process(target=func2, args=(2, self.mp_queue))
             p.start()
             p.join()
             self.assertEqual(-signal.SIGKILL, p.exitcode)
@@ -102,7 +103,6 @@ if not (IS_WINDOWS or IS_MACOS):
             with timer.expires(after=timeout):
                 time.sleep(duration)

-        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
         def test_timer(self):
             timeout = 0.1
             duration = 1
@@ -124,7 +124,7 @@ if not (IS_WINDOWS or IS_MACOS):

 # timer is not supported on windows or macos
-if not (IS_WINDOWS or IS_MACOS):
+if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):

     class MultiprocessingRequestQueueTest(unittest.TestCase):
         def test_get(self):
             mp_queue = mp.Queue()
@@ -183,7 +183,7 @@ if not (IS_WINDOWS or IS_MACOS):

 # timer is not supported on windows or macos
-if not (IS_WINDOWS or IS_MACOS):
+if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):

     class LocalTimerServerTest(unittest.TestCase):
         def setUp(self):
             self.mp_queue = mp.Queue()
@@ -193,7 +193,6 @@ if not (IS_WINDOWS or IS_MACOS):
         def tearDown(self):
             self.server.stop()

-        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
         def test_watchdog_call_count(self):
             """
             checks that the watchdog function ran wait/interval +- 1 times
@@ -226,7 +225,6 @@ if not (IS_WINDOWS or IS_MACOS):
         def _release_timer(self, pid, scope):
             return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)

-        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
         @mock.patch("os.kill")
         def test_expired_timers(self, mock_os_kill):
             """

View File

@@ -31,7 +31,6 @@ from torch.distributed.launcher.api import (
 )
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
-    TEST_WITH_TSAN,
     sandcastle_skip_if,
 )

@@ -117,7 +116,7 @@ class ElasticLaunchTest(unittest.TestCase):
            rdzv_endpoint=endpoint,
            monitor_interval=1,
            rdzv_backend=rdzv_backend,
-           start_method="fork",
+           start_method="spawn",
            max_restarts=0,
            rdzv_configs=rdzv_configs,
        )
@@ -128,7 +127,7 @@ class ElasticLaunchTest(unittest.TestCase):
         )

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_script_python(self):
         nnodes = 1
@@ -145,7 +144,7 @@ class ElasticLaunchTest(unittest.TestCase):
         self.check_works_ran(world_size)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_script_python_local_rank_transfer(self):
         nnodes = 1
@@ -162,7 +161,7 @@ class ElasticLaunchTest(unittest.TestCase):
         self.check_works_ran(world_size)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_script_bash(self):
         nnodes = 1
@@ -177,7 +176,7 @@ class ElasticLaunchTest(unittest.TestCase):
         self.check_works_ran(world_size)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_function(self):
         nnodes = 1
@@ -193,7 +192,7 @@ class ElasticLaunchTest(unittest.TestCase):
         self.assertEqual(expected_res, actual_res)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_dist_sum_with_static_rdzv(self):
         nnodes = 1
@@ -224,7 +223,7 @@ class ElasticLaunchTest(unittest.TestCase):
         self.assertEqual(expected_res, actual_res)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_elastic(self):
         nproc_per_node = 4

View File

@@ -15,7 +15,6 @@ import torch.distributed.launch as launch
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
-    TEST_WITH_TSAN,
     sandcastle_skip_if,
 )

@@ -36,7 +35,7 @@ class LaunchTest(unittest.TestCase):
         shutil.rmtree(self.test_dir)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_without_env(self):
         nnodes = 1
@@ -49,7 +48,7 @@ class LaunchTest(unittest.TestCase):
             f"--nnodes={nnodes}",
             f"--nproc_per_node={nproc_per_node}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--master_addr=localhost",
             f"--master_port={master_port}",
             "--node_rank=0",
@@ -58,7 +57,7 @@ class LaunchTest(unittest.TestCase):
         launch.main(args)

     @sandcastle_skip_if(
-        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
+        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
     )
     def test_launch_with_env(self):
         nnodes = 1
@@ -71,7 +70,7 @@ class LaunchTest(unittest.TestCase):
             f"--nnodes={nnodes}",
             f"--nproc_per_node={nproc_per_node}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--master_addr=localhost",
             f"--master_port={master_port}",
             "--node_rank=0",

View File

@@ -23,7 +23,6 @@ from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
-    TEST_WITH_TSAN,
     sandcastle_skip_if,
 )

@@ -100,7 +99,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -123,7 +122,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--nnodes={nnodes}",
             f"--nproc_per_node={nproc_per_node}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--master_addr=localhost",
             f"--master_port={master_port}",
             "--node_rank=0",
@@ -138,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_user_script_bash(self):
         run_id = str(uuid.uuid4().int)
         nnodes = 1
@@ -151,7 +150,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--no_python",
         ]

@@ -169,7 +168,7 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_user_script_default_nproc(self):
         run_id = str(uuid.uuid4().int)
         nnodes = 1
@@ -180,7 +179,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--no_python",
         ]

@@ -198,7 +197,7 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_with_env_vars(self):
         run_id = str(uuid.uuid4().int)
         nnodes = 1
@@ -211,7 +210,7 @@ class ElasticLaunchTest(unittest.TestCase):
         os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
         os.environ["PET_RDZV_ID"] = run_id
         os.environ["PET_MONITOR_INTERVAL"] = "1"
-        os.environ["PET_START_METHOD"] = "fork"
+        os.environ["PET_START_METHOD"] = "spawn"
         os.environ["PET_NO_PYTHON"] = "1"

         script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
@@ -241,7 +240,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             "--no_python",
         ]

@@ -256,27 +255,27 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_nproc_launch_auto_configurations(self):
         self._test_nproc_launch_configuration("auto", os.cpu_count())

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_nproc_launch_number_configurations(self):
         self._test_nproc_launch_configuration("4", 4)

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_nproc_launch_unknown_configurations(self):
         with self.assertRaises(ValueError):
             self._test_nproc_launch_configuration("unknown", 4)

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     @patch("torch.cuda.is_available", return_value=True)
     @patch("torch.cuda.device_count", return_value=3)
     def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
         self._test_nproc_launch_configuration("auto", 3)
         self._test_nproc_launch_configuration("gpu", 3)

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_elastic(self):
         run_id = str(uuid.uuid4().int)
         min_nodes = 1
@@ -291,7 +290,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -304,7 +303,7 @@ class ElasticLaunchTest(unittest.TestCase):
         )

     @mock.patch("torch.distributed.elastic.events.record")
-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_elastic_worker_raise_exception(self, record_mock):
         """
         Asserts that when the worker program fails and lancher raieses exception
@@ -323,7 +322,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
             "--max_restarts=0",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             "--fail",
         ]
@@ -332,7 +331,7 @@ class ElasticLaunchTest(unittest.TestCase):

         record_mock.assert_called_once()

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     @mock.patch(
         "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
     )
@@ -354,7 +353,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
             "--max_restarts=0",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -364,7 +363,7 @@ class ElasticLaunchTest(unittest.TestCase):
             launch.main(args)
         record_mock.assert_called_once()

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_standalone(self):
         nnodes = 1
         nproc_per_node = 4
@@ -374,7 +373,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--nproc_per_node={nproc_per_node}",
             "--standalone",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -386,7 +385,7 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_run_path(self):
         nnodes = 1
         nproc_per_node = 4
@@ -396,7 +395,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--nnodes={nnodes}",
             f"--nproc_per_node={nproc_per_node}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -408,7 +407,7 @@ class ElasticLaunchTest(unittest.TestCase):
             {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
         )

-    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
     def test_launch_elastic_multiple_agents(self):
         run_id = str(uuid.uuid4().int)
         min_nodes = 1
@@ -423,7 +422,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--rdzv_endpoint={self._etcd_endpoint}",
             f"--rdzv_id={run_id}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
@@ -462,7 +461,7 @@ class ElasticLaunchTest(unittest.TestCase):
             f"--nnodes={nnodes}",
             f"--nproc_per_node={nproc_per_node}",
             "--monitor_interval=1",
-            "--start_method=fork",
+            "--start_method=spawn",
             path("bin/test_script.py"),
             f"--touch_file_dir={self.test_dir}",
         ]
View File
@ -28,9 +28,13 @@ from torch.testing._internal.common_utils import (
TestCase, TestCase,
load_tests, load_tests,
run_tests, run_tests,
TEST_WITH_TSAN, TEST_WITH_DEV_DBG_ASAN,
) )
if TEST_WITH_DEV_DBG_ASAN:
print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
sys.exit(0)
# load_tests from common_utils is used to automatically filter tests for # load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings # sharding on sandcastle. This line silences flake warnings
load_tests = load_tests load_tests = load_tests
@ -438,37 +442,31 @@ class AbstractDistributedDataParallelTest(object):
return fut.then(fut_then) return fut.then(fut_then)
# TSAN is not fork-safe since we're forking in a multi-threaded environment class DistributedDataParallelTest(
if not TEST_WITH_TSAN: AbstractDistributedDataParallelTest, MultiProcessTestCase
):
def setUp(self):
super(DistributedDataParallelTest, self).setUp()
self._spawn_processes()
class DistributedDataParallelTest( def test_invalid_powerSGD_state(self):
AbstractDistributedDataParallelTest, MultiProcessTestCase for start_powerSGD_iter, use_error_feedback, warm_start in product(
): [0, 1], [True, False], [True, False]
def setUp(self): ):
super(DistributedDataParallelTest, self).setUp() if not use_error_feedback and not warm_start:
if sys.platform == "win32": continue
self._spawn_processes() with self.assertRaisesRegex(
else: ValueError,
self._fork_processes() "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
"because PowerSGD can only be applied after the first two iterations in DDP.",
def test_invalid_powerSGD_state(self):
for start_powerSGD_iter, use_error_feedback, warm_start in product(
[0, 1], [True, False], [True, False]
): ):
if not use_error_feedback and not warm_start: state = powerSGD.PowerSGDState(
continue process_group=None,
with self.assertRaisesRegex( matrix_approximation_rank=1,
ValueError, start_powerSGD_iter=start_powerSGD_iter,
"Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " use_error_feedback=use_error_feedback,
"because PowerSGD can only be applied after the first two iterations in DDP.", warm_start=warm_start,
): )
state = powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=start_powerSGD_iter,
use_error_feedback=use_error_feedback,
warm_start=warm_start,
)
class ComputeBucketAssignmentTest(TestCase): class ComputeBucketAssignmentTest(TestCase):
@ -656,49 +654,42 @@ class AbstractCommTest(object):
dist.all_gather_object(obj_list, subgroup_seq, group=subgroup) dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
self.assertEqual(len(set(obj_list)), 1) self.assertEqual(len(set(obj_list)), 1)
class CommTest(AbstractCommTest, MultiProcessTestCase):
def setUp(self):
super(CommTest, self).setUp()
self._spawn_processes()
# TSAN is not fork-safe since we're forking in a multi-threaded environment def tearDown(self):
if not TEST_WITH_TSAN: super(CommTest, self).tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
class CommTest(AbstractCommTest, MultiProcessTestCase): def test_distributed_debug_mode(self):
def setUp(self): # Default should be off
super(CommTest, self).setUp() default_debug_mode = dist._get_debug_mode()
if sys.platform == "win32": self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF)
self._spawn_processes() mapping = {
else: "OFF": dist._DistributedDebugLevel.OFF,
self._fork_processes() "INFO": dist._DistributedDebugLevel.INFO,
"DETAIL": dist._DistributedDebugLevel.DETAIL,
}
invalid_debug_modes = ["foo", 0, 1, -1]
def tearDown(self): for mode in mapping.keys():
super(CommTest, self).tearDown() os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
try: set_debug_mode = dist._get_debug_mode()
os.remove(self.file_name) self.assertEqual(
except OSError: set_debug_mode,
pass mapping[mode],
f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
)
def test_distributed_debug_mode(self): for mode in invalid_debug_modes:
# Default should be off os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
default_debug_mode = dist._get_debug_mode() with self.assertRaisesRegex(RuntimeError, "to be one of"):
self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF) dist._get_debug_mode()
mapping = {
"OFF": dist._DistributedDebugLevel.OFF,
"INFO": dist._DistributedDebugLevel.INFO,
"DETAIL": dist._DistributedDebugLevel.DETAIL,
}
invalid_debug_modes = ["foo", 0, 1, -1]
for mode in mapping.keys():
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
set_debug_mode = dist._get_debug_mode()
self.assertEqual(
set_debug_mode,
mapping[mode],
f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
)
for mode in invalid_debug_modes:
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
with self.assertRaisesRegex(RuntimeError, "to be one of"):
dist._get_debug_mode()
if __name__ == "__main__": if __name__ == "__main__":
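As the debug-mode test above spells out, TORCH_DISTRIBUTED_DEBUG accepts exactly OFF, INFO, or DETAIL, and anything else makes the lookup fail. A minimal standalone sketch of that contract, using the same private helpers the test uses (_get_debug_mode, _DistributedDebugLevel):

    import os
    import torch.distributed as dist

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    assert dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "garbage"
    try:
        dist._get_debug_mode()               # raises: value must be one of OFF/INFO/DETAIL
    except RuntimeError as err:
        assert "to be one of" in str(err)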
View File
@ -43,17 +43,9 @@ from torch.testing._internal.common_utils import (
TestCase, TestCase,
run_tests, run_tests,
retry_on_connect_failures, retry_on_connect_failures,
TEST_WITH_TSAN,
sandcastle_skip, sandcastle_skip,
) )
if TEST_WITH_TSAN:
print(
"Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment",
file=sys.stderr,
)
sys.exit(0)
def simple_reduce_tests(rank, world_size): def simple_reduce_tests(rank, world_size):
tests = [ tests = [
@ -218,12 +210,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
def setUp(self): def setUp(self):
super(ProcessGroupGlooTest, self).setUp() super(ProcessGroupGlooTest, self).setUp()
self._spawn_processes()
# For Windows platform, Python does not support fork, change it to spawn here.
if sys.platform == "win32":
self._spawn_processes()
else:
self._fork_processes()
def opts(self, threads=2): def opts(self, threads=2):
opts = c10d.ProcessGroupGloo._Options() opts = c10d.ProcessGroupGloo._Options()
@ -1425,10 +1412,7 @@ class DistributedDataParallelTest(
): ):
def setUp(self): def setUp(self):
super(DistributedDataParallelTest, self).setUp() super(DistributedDataParallelTest, self).setUp()
if sys.platform == "win32": self._spawn_processes()
self._spawn_processes()
else:
self._fork_processes()
def _test_gloo_backend( def _test_gloo_backend(
self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
@ -2197,10 +2181,7 @@ class ReducerTest(TestCase):
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
def setUp(self): def setUp(self):
super(CommTest, self).setUp() super(CommTest, self).setUp()
if sys.platform == "win32": self._spawn_processes()
self._spawn_processes()
else:
self._fork_processes()
def tearDown(self): def tearDown(self):
super(CommTest, self).tearDown() super(CommTest, self).tearDown()
View File
@ -45,7 +45,6 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures, retry_on_connect_failures,
TEST_WITH_DEV_DBG_ASAN, TEST_WITH_DEV_DBG_ASAN,
TEST_WITH_ROCM, TEST_WITH_ROCM,
TEST_WITH_TSAN,
sandcastle_skip, sandcastle_skip,
sandcastle_skip_if, sandcastle_skip_if,
) )
@ -57,13 +56,6 @@ if not IS_WINDOWS:
from torch.distributed.optim.functional_adam import _FunctionalAdam from torch.distributed.optim.functional_adam import _FunctionalAdam
from torch.distributed.optim.functional_adamw import _FunctionalAdamW from torch.distributed.optim.functional_adamw import _FunctionalAdamW
if TEST_WITH_TSAN:
print(
"Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment",
file=sys.stderr,
)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN: if TEST_WITH_DEV_DBG_ASAN:
print( print(
"Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
View File
@ -11,7 +11,7 @@ from test_c10d_spawn import _torch_dist_nn_available
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_gloo, \ from torch.testing._internal.common_distributed import requires_gloo, \
create_device, MultiProcessTestCase, skip_if_lt_x_gpu create_device, MultiProcessTestCase, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_TSAN, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9): if sys.version_info < (3, 9):
@ -76,102 +76,100 @@ if sys.version_info < (3, 9):
self.world_size) self.world_size)
# TSAN is not fork-safe since we're forking in a multi-threaded environment class DistributedDataParallelSingleProcessTest(TestCase):
if not TEST_WITH_TSAN: def setUp(self):
class DistributedDataParallelSingleProcessTest(TestCase): self.rank = 0
def setUp(self): self.world_size = 1
self.rank = 0 self.file = tempfile.NamedTemporaryFile(delete=False) # noqa: P201
self.world_size = 1
self.file = tempfile.NamedTemporaryFile(delete=False) # noqa: P201
def tearDown(self): def tearDown(self):
try: try:
os.remove(self.file.name) os.remove(self.file.name)
except OSError: except OSError:
pass pass
def _test_base(self, net, inp, check_allclose=True): def _test_base(self, net, inp, check_allclose=True):
store = c10d.FileStore(self.file.name, self.world_size) store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
if inp[0].is_cuda: if inp[0].is_cuda:
device_ids = [torch.cuda.current_device()] device_ids = [torch.cuda.current_device()]
else: else:
device_ids = None device_ids = None
ddp = nn.parallel.DistributedDataParallel( ddp = nn.parallel.DistributedDataParallel(
copy.deepcopy(net), copy.deepcopy(net),
device_ids=device_ids, device_ids=device_ids,
process_group=process_group process_group=process_group
) )
net_opt = torch.optim.Adam(net.parameters(), lr=0.001) net_opt = torch.optim.Adam(net.parameters(), lr=0.001)
ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001) ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001)
for i, j in zip(ddp.parameters(), net.parameters()):
self.assertTrue(i.allclose(j))
for _ in range(10):
net_out = net(*inp)
ddp_out = ddp(*inp)
net_out.sum().backward()
ddp_out.sum().backward()
net_opt.step()
ddp_opt.step()
if check_allclose:
for i, j in zip(ddp.parameters(), net.parameters()): for i, j in zip(ddp.parameters(), net.parameters()):
self.assertTrue(i.allclose(j)) self.assertTrue(i.allclose(j))
for _ in range(10): @requires_gloo()
net_out = net(*inp) def test_cpu(self):
ddp_out = ddp(*inp) self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])
net_out.sum().backward() @requires_gloo()
ddp_out.sum().backward() @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
def test_cuda(self):
self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)])
net_opt.step() @requires_gloo()
ddp_opt.step() @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
def test_rnn(self):
# This test is inspired by the bug reported in
# https://github.com/pytorch/pytorch/issues/36268
BATCH_SIZE = 12 # Divisible by 2, 3, 4
INPUT_DIM = 256
OUTPUT_DIM = 256
HIDDEN_DIM = 256
N_LAYERS = 3
SEQ_LEN = 100
if check_allclose: class Net(nn.Module):
for i, j in zip(ddp.parameters(), net.parameters()): def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
self.assertTrue(i.allclose(j)) super(Net, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.hidden_layers = hidden_layers
@requires_gloo() self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True)
def test_cpu(self): self.h2o = nn.Linear(hidden_dim, output_dim)
self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])
@requires_gloo() def forward(self, x, y):
@sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") self.lstm.flatten_parameters()
def test_cuda(self): h_t, _ = self.lstm(x)
self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)]) output = self.h2o(h_t)
loss = nn.functional.mse_loss(output, y)
return loss
@requires_gloo() net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
@sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") inp = [
def test_rnn(self): torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
# This test is inspired by the bug reported in torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0)
# https://github.com/pytorch/pytorch/issues/36268 ]
BATCH_SIZE = 12 # Divisible by 2, 3, 4
INPUT_DIM = 256
OUTPUT_DIM = 256
HIDDEN_DIM = 256
N_LAYERS = 3
SEQ_LEN = 100
class Net(nn.Module): # Not checking result allclose as the parameter inconsistency exist
def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers): # prior to this change. See #37079
super(Net, self).__init__() self._test_base(net, inp, check_allclose=False)
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.hidden_layers = hidden_layers
self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True)
self.h2o = nn.Linear(hidden_dim, output_dim)
def forward(self, x, y):
self.lstm.flatten_parameters()
h_t, _ = self.lstm(x)
output = self.h2o(h_t)
loss = nn.functional.mse_loss(output, y)
return loss
net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
inp = [
torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0)
]
# Not checking result allclose as the parameter inconsistency exist
# prior to this change. See #37079
self._test_base(net, inp, check_allclose=False)
# Skip dev-asan as torch + multiprocessing spawn have known issues # Skip dev-asan as torch + multiprocessing spawn have known issues
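The single-process DDP test above reduces to the following setup; this is just the FileStore/ProcessGroupGloo pattern from the hunk condensed into one runnable snippet, not a new API:

    import copy
    import tempfile

    import torch
    import torch.distributed as c10d
    import torch.nn as nn

    file = tempfile.NamedTemporaryFile(delete=False)
    store = c10d.FileStore(file.name, 1)                 # world_size == 1
    pg = c10d.ProcessGroupGloo(store, 0, 1)              # rank 0 of 1

    net = nn.Linear(2, 2)
    ddp = nn.parallel.DistributedDataParallel(copy.deepcopy(net), process_group=pg)
    ddp(torch.randn(30, 2)).sum().backward()             # gradients land on ddp's copy of net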
View File
@ -1,113 +0,0 @@
import os
import sys
import tempfile
from functools import wraps
import torch
import torch.cuda
import torch.distributed as dist
from torch.testing._internal.common_utils import TEST_WITH_TSAN
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_utils import TestCase, find_free_port, run_tests
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing._internal.distributed.distributed_test import (
DistributedTest, TestDistBackend
)
torch.backends.cuda.matmul.allow_tf32 = False
CPP_EXTENSIONS_WARNING = """
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,
but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`.
"""
BACKEND = os.environ["BACKEND"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
def skip_if_no_ninja(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
import torch.utils.cpp_extension
torch.utils.cpp_extension.verify_ninja_availability()
except RuntimeError:
print(CPP_EXTENSIONS_WARNING)
return 0
return func(*args, **kwargs)
return wrapper
if TEST_WITH_TSAN:
print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
sys.exit(0)
if BACKEND == "gloo" or BACKEND == "nccl":
class TestDistBackendWithFork(TestDistBackend, DistributedTest._DistTestBase):
def setUp(self):
super().setUp()
self._fork_processes()
torch.backends.cudnn.flags(allow_tf32=False).__enter__()
elif BACKEND == "mpi":
WORLD_SIZE = os.environ["WORLD_SIZE"]
dist.init_process_group(init_method=INIT_METHOD, backend="mpi")
class TestMPIWithFork(TestCase, DistributedTest._DistTestBase):
pass
elif BACKEND == "test":
class TestBackendDynamicLoad(TestCase):
def setUp(self):
super(TestBackendDynamicLoad, self).setUp()
def _load_test_backend(self):
temp_dir = tempfile.mkdtemp()
src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__)))
extension = torch.utils.cpp_extension.load(
name="torch_test",
sources=[src],
build_directory=temp_dir
)
@skip_if_no_ninja
def test_backend_apis(self):
self._load_test_backend()
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(find_free_port())
os.environ['RANK'] = '0'
dist.init_process_group(backend='test', init_method='env://', world_size=1, rank=0)
self.assertEqual(dist.get_rank(), 0)
self.assertEqual(dist.get_world_size(), 1)
process_group = _get_default_group()
work = process_group.allreduce([torch.rand(1), torch.rand(1)])
self.assertTrue(work.wait())
self.assertTrue(work.is_completed())
self.assertTrue(work.is_success())
work = process_group.broadcast([torch.rand(1)])
self.assertTrue(work.wait())
self.assertTrue(work.is_completed())
self.assertTrue(work.is_success())
dist.destroy_process_group()
if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"
run_tests()
View File
@ -6,7 +6,7 @@ import time
from typing import List from typing import List
from torch.testing._internal.common_distributed import requires_nccl, create_tcp_store from torch.testing._internal.common_distributed import requires_nccl, create_tcp_store
from torch.testing._internal.common_utils import load_tests, TEST_WITH_TSAN, run_tests, sandcastle_skip_if from torch.testing._internal.common_utils import load_tests, run_tests, sandcastle_skip_if
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
# load_tests from common_utils is used to automatically filter tests for # load_tests from common_utils is used to automatically filter tests for
@ -29,10 +29,6 @@ def unique_process_group_name(prefix):
now = int(time.time() * 1000) now = int(time.time() * 1000)
return "%s_%d" % (prefix, now) return "%s_%d" % (prefix, now)
if TEST_WITH_TSAN:
print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
sys.exit(0)
class ProcessGroupNCCLJitTest(JitTestCase): class ProcessGroupNCCLJitTest(JitTestCase):
MAIN_PROCESS_RANK = 0 MAIN_PROCESS_RANK = 0
View File
@ -12,7 +12,6 @@ if not dist.is_available():
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN, TEST_WITH_DEV_DBG_ASAN,
TEST_WITH_TSAN,
TestCase, TestCase,
run_tests, run_tests,
) )
@ -25,10 +24,6 @@ if TEST_WITH_DEV_DBG_ASAN:
print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr) print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr)
sys.exit(0) sys.exit(0)
if TEST_WITH_TSAN:
print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
sys.exit(0)
class TestDistributedLaunch(TestCase): class TestDistributedLaunch(TestCase):
def test_launch_user_script(self): def test_launch_user_script(self):
nnodes = 1 nnodes = 1
@ -41,7 +36,7 @@ class TestDistributedLaunch(TestCase):
f"--nnodes={nnodes}", f"--nnodes={nnodes}",
f"--nproc_per_node={nproc_per_node}", f"--nproc_per_node={nproc_per_node}",
"--monitor_interval=1", "--monitor_interval=1",
"--start_method=fork", "--start_method=spawn",
"--master_addr=localhost", "--master_addr=localhost",
f"--master_port={master_port}", f"--master_port={master_port}",
"--node_rank=0", "--node_rank=0",
View File
@ -20,7 +20,6 @@ from torch.testing._internal.common_distributed import (
) )
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
TEST_WITH_TSAN,
TEST_WITH_DEV_DBG_ASAN, TEST_WITH_DEV_DBG_ASAN,
) )
@ -28,11 +27,7 @@ from torch.testing._internal.common_utils import (
class AbstractProcessGroupWrapperTest(MultiProcessTestCase): class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
def setUp(self): def setUp(self):
super(AbstractProcessGroupWrapperTest, self).setUp() super(AbstractProcessGroupWrapperTest, self).setUp()
# For Windows platform, Python does not support fork, change it to spawn here. self._spawn_processes()
if sys.platform == "win32":
self._spawn_processes()
else:
self._fork_processes()
def _validate_error(self, exception, op_type, rank, tensor): def _validate_error(self, exception, op_type, rank, tensor):
err = str(exception) err = str(exception)
@ -291,91 +286,89 @@ if not TEST_WITH_DEV_DBG_ASAN:
self._test_collective_shape_mismatch(pg, use_cuda=True) self._test_collective_shape_mismatch(pg, use_cuda=True)
# TSAN is not fork-safe since we're forking in a multi-threaded environment @requires_gloo()
if not TEST_WITH_TSAN: class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
@requires_gloo() def setUp(self):
class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): super(ProcessGroupGlooWrapperTest, self).setUp()
def setUp(self):
super(ProcessGroupGlooWrapperTest, self).setUp()
def opts(self, threads=2, timeout=10.0): def opts(self, threads=2, timeout=10.0):
opts = c10d.ProcessGroupGloo._Options() opts = c10d.ProcessGroupGloo._Options()
opts._timeout = timeout opts._timeout = timeout
opts._devices = [create_device(interface=LOOPBACK)] opts._devices = [create_device(interface=LOOPBACK)]
opts._threads = threads opts._threads = threads
return opts return opts
def _create_wrapper_pg(self, with_new_group=False, timeout=10.0): def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
store = c10d.FileStore(self.file_name, self.world_size) store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group( c10d.init_process_group(
backend="gloo", rank=self.rank, world_size=self.world_size, store=store backend="gloo", rank=self.rank, world_size=self.world_size, store=store
)
if with_new_group:
pg = c10d.new_group(backend="gloo")
else:
_pg = c10d.ProcessGroupGloo(
store, self.rank, self.world_size, self.opts(timeout=timeout)
) )
if with_new_group: pg = c10d._create_process_group_wrapper(
pg = c10d.new_group(backend="gloo") _pg,
else: "unused",
_pg = c10d.ProcessGroupGloo( store,
store, self.rank, self.world_size, self.opts(timeout=timeout) self.rank,
) self.world_size,
pg = c10d._create_process_group_wrapper( timeout=timeout,
_pg, )
"unused", return pg
store,
self.rank,
self.world_size,
timeout=timeout,
)
return pg
def test_collective_hang(self): def test_collective_hang(self):
pg = self._create_wrapper_pg(timeout=2.0) pg = self._create_wrapper_pg(timeout=2.0)
self._test_collective_hang(pg) self._test_collective_hang(pg)
# NOTE: these tests are separated by debug level instead of combined into # NOTE: these tests are separated by debug level instead of combined into
# one due to https://github.com/pytorch/pytorch/issues/55967, they can be # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
# combined after that is resolved. # combined after that is resolved.
@with_dist_debug_levels(levels=["DETAIL"]) @with_dist_debug_levels(levels=["DETAIL"])
def test_collectives_op_mismatch_debug_mode(self): def test_collectives_op_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True) pg = self._create_wrapper_pg(with_new_group=True)
self._test_collectives_op_mismatch(pg) self._test_collectives_op_mismatch(pg)
@with_dist_debug_levels(levels=["OFF"]) @with_dist_debug_levels(levels=["OFF"])
def test_collectives_op_mismatch(self): def test_collectives_op_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False) pg = self._create_wrapper_pg(with_new_group=False)
self._test_collectives_op_mismatch(pg) self._test_collectives_op_mismatch(pg)
@with_dist_debug_levels(levels=["DETAIL"]) @with_dist_debug_levels(levels=["DETAIL"])
def test_collective_shape_mismatch_debug_mode(self): def test_collective_shape_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True) pg = self._create_wrapper_pg(with_new_group=True)
self._test_collective_shape_mismatch(pg) self._test_collective_shape_mismatch(pg)
@with_dist_debug_levels(levels=["OFF"]) @with_dist_debug_levels(levels=["OFF"])
def test_collective_shape_mismatch(self): def test_collective_shape_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False) pg = self._create_wrapper_pg(with_new_group=False)
self._test_collective_shape_mismatch(pg) self._test_collective_shape_mismatch(pg)
@skip_if_lt_x_gpu(4) @skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["DETAIL"]) @with_dist_debug_levels(levels=["DETAIL"])
def test_collectives_op_mismatch_cuda_debug_mode(self): def test_collectives_op_mismatch_cuda_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True) pg = self._create_wrapper_pg(with_new_group=True)
self._test_collectives_op_mismatch(pg, use_cuda=True) self._test_collectives_op_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4) @skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["OFF"]) @with_dist_debug_levels(levels=["OFF"])
def test_collectives_op_mismatch_cuda(self): def test_collectives_op_mismatch_cuda(self):
pg = self._create_wrapper_pg(with_new_group=False) pg = self._create_wrapper_pg(with_new_group=False)
self._test_collectives_op_mismatch(pg, use_cuda=True) self._test_collectives_op_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4) @skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["DETAIL"]) @with_dist_debug_levels(levels=["DETAIL"])
def test_collective_shape_mismatch_cuda_debug_mode(self): def test_collective_shape_mismatch_cuda_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True) pg = self._create_wrapper_pg(with_new_group=True)
self._test_collective_shape_mismatch(pg, use_cuda=True) self._test_collective_shape_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4) @skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["OFF"]) @with_dist_debug_levels(levels=["OFF"])
def test_collective_shape_mismatch_cuda(self): def test_collective_shape_mismatch_cuda(self):
pg = self._create_wrapper_pg(with_new_group=False) pg = self._create_wrapper_pg(with_new_group=False)
self._test_collective_shape_mismatch(pg, use_cuda=True) self._test_collective_shape_mismatch(pg, use_cuda=True)
if __name__ == "__main__": if __name__ == "__main__":
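Condensed from the wrapper tests above: with with_new_group=True the group comes from c10d.new_group under the decorated debug level, otherwise a raw ProcessGroupGloo is wrapped explicitly through the private helper. The explicit path, as it appears in the hunk:

    store = c10d.FileStore(self.file_name, self.world_size)
    c10d.init_process_group(
        backend="gloo", rank=self.rank, world_size=self.world_size, store=store
    )
    _pg = c10d.ProcessGroupGloo(
        store, self.rank, self.world_size, self.opts(timeout=10.0)
    )
    pg = c10d._create_process_group_wrapper(
        _pg, "unused", store, self.rank, self.world_size, timeout=10.0
    )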
View File
@ -65,7 +65,6 @@ TESTS = [
'test_dataloader', 'test_dataloader',
'test_datapipe', 'test_datapipe',
'distributed/test_data_parallel', 'distributed/test_data_parallel',
'distributed/test_distributed_fork',
'distributed/test_distributed_spawn', 'distributed/test_distributed_spawn',
'distributions/test_constraints', 'distributions/test_constraints',
'distributions/test_distributions', 'distributions/test_distributions',
@ -212,7 +211,6 @@ WINDOWS_BLOCKLIST = [
'distributed/rpc/test_faulty_agent', 'distributed/rpc/test_faulty_agent',
'distributed/rpc/test_tensorpipe_agent', 'distributed/rpc/test_tensorpipe_agent',
'distributed/rpc/cuda/test_tensorpipe_agent', 'distributed/rpc/cuda/test_tensorpipe_agent',
'distributed/test_distributed_fork',
'distributed/pipeline/sync/skip/test_api', 'distributed/pipeline/sync/skip/test_api',
'distributed/pipeline/sync/skip/test_gpipe', 'distributed/pipeline/sync/skip/test_gpipe',
'distributed/pipeline/sync/skip/test_inspect_skip_layout', 'distributed/pipeline/sync/skip/test_inspect_skip_layout',
@ -294,7 +292,6 @@ TARGET_DET_LIST = [
'test_testing', 'test_testing',
'test_view_ops', 'test_view_ops',
'distributed/nn/jit/test_instantiator', 'distributed/nn/jit/test_instantiator',
'distributed/test_distributed_fork',
'distributed/rpc/test_tensorpipe_agent', 'distributed/rpc/test_tensorpipe_agent',
'distributed/rpc/cuda/test_tensorpipe_agent', 'distributed/rpc/cuda/test_tensorpipe_agent',
'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks', 'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks',
@ -576,7 +573,7 @@ def test_distributed(test_module, test_directory, options):
os.environ['INIT_METHOD'] = 'env://' os.environ['INIT_METHOD'] = 'env://'
os.environ.update(env_vars) os.environ.update(env_vars)
if with_init_file: if with_init_file:
if test_module in ["test_distributed_fork", "test_distributed_spawn"]: if test_module == "test_distributed_spawn":
init_method = f'{FILE_SCHEMA}{tmp_dir}/' init_method = f'{FILE_SCHEMA}{tmp_dir}/'
else: else:
init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file' init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file'
@ -611,7 +608,6 @@ CUSTOM_HANDLERS = {
'test_cuda_primary_ctx': test_cuda_primary_ctx, 'test_cuda_primary_ctx': test_cuda_primary_ctx,
'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja, 'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja,
'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja, 'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja,
'distributed/test_distributed_fork': test_distributed,
'distributed/test_distributed_spawn': test_distributed, 'distributed/test_distributed_spawn': test_distributed,
} }
View File
@ -16,7 +16,6 @@ class DeterminationTest(unittest.TestCase):
"test_jit_profiling", "test_jit_profiling",
"test_jit", "test_jit",
"test_torch", "test_torch",
"distributed/test_distributed_fork",
"distributed/test_distributed_spawn", "distributed/test_distributed_spawn",
"test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_ninja",
"test_cpp_extensions_aot_no_ninja", "test_cpp_extensions_aot_no_ninja",
@ -104,7 +103,6 @@ class DeterminationTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
self.determined_tests(["torch/utils/cpp_extension.py"]), self.determined_tests(["torch/utils/cpp_extension.py"]),
[ [
"distributed/test_distributed_fork",
"test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_ninja",
"test_cpp_extensions_aot_no_ninja", "test_cpp_extensions_aot_no_ninja",
"test_utils", "test_utils",
View File
@ -630,7 +630,6 @@ class TestFile:
def append(self, test_case: TestCase, test_type: str) -> None: def append(self, test_case: TestCase, test_type: str) -> None:
is_multi_test = self.name == 'test_cpp_extensions_aot' or \ is_multi_test = self.name == 'test_cpp_extensions_aot' or \
self.name == 'distributed/test_distributed_fork' or \
self.name == 'distributed/test_distributed_spawn' or \ self.name == 'distributed/test_distributed_spawn' or \
self.name == 'distributed/test_c10d_gloo' or \ self.name == 'distributed/test_c10d_gloo' or \
self.name == 'cpp' # The caffe2 cpp tests spawn duplicate test cases as well. self.name == 'cpp' # The caffe2 cpp tests spawn duplicate test cases as well.
View File
@ -85,7 +85,6 @@ python test/distributed/test_store.py
python test/distributed/test_pg_wrapper.py python test/distributed/test_pg_wrapper.py
# Run distributed tests, including tests for Distributed Data Parallel. # Run distributed tests, including tests for Distributed Data Parallel.
python test/run_test.py --verbose -i distributed/test_distributed_fork
python test/run_test.py --verbose -i distributed/test_distributed_spawn python test/run_test.py --verbose -i distributed/test_distributed_spawn
# Run the RPC test suite for the TensorPipeAgent. # Run the RPC test suite for the TensorPipeAgent.
View File
@ -9,6 +9,7 @@ import time
import traceback import traceback
import types import types
import unittest import unittest
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
@ -468,6 +469,10 @@ class MultiProcessTestCase(TestCase):
self.processes.append(process) self.processes.append(process)
def _fork_processes(self) -> None: def _fork_processes(self) -> None:
warnings.warn(
"Fork based multiprocessing is dangerous and should not"
" be used, for tests with ASAN consider using opt-asan",
DeprecationWarning)
proc = torch.multiprocessing.get_context("fork").Process proc = torch.multiprocessing.get_context("fork").Process
self._start_processes(proc) self._start_processes(proc)
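Taken together with the earlier patches in this stack, the per-suite boilerplate these files converge on looks roughly like this (the class name is illustrative; the guard and setUp mirror the hunks above):

    import sys

    from torch.testing._internal.common_distributed import MultiProcessTestCase
    from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

    if TEST_WITH_DEV_DBG_ASAN:
        print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
        sys.exit(0)

    class ExampleCommTest(MultiProcessTestCase):
        def setUp(self):
            super().setUp()
            self._spawn_processes()          # spawn everywhere; no win32/fork branching left

    if __name__ == "__main__":
        run_tests()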