mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[8/N] Remove c10d/ddp fork tests. (#63454)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63454

As a continuation of https://github.com/pytorch/pytorch/pull/63443, this PR removes all fork tests from torch.distributed.
ghstack-source-id: 136285511

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30387872

fbshipit-source-id: f6d6313db126ae7b95b86f78a1e0726887c5c513
This commit is contained in:
		
				
					committed by
					
						
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							71da114412
						
					
				
				
					commit
					2d671ca41b
				
			@ -19,7 +19,6 @@ fi
 | 
				
			|||||||
python tools/download_mnist.py --quiet -d test/cpp/api/mnist
 | 
					python tools/download_mnist.py --quiet -d test/cpp/api/mnist
 | 
				
			||||||
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api
 | 
					OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api
 | 
				
			||||||
time python test/run_test.py --verbose -i distributed/test_jit_c10d
 | 
					time python test/run_test.py --verbose -i distributed/test_jit_c10d
 | 
				
			||||||
time python test/run_test.py --verbose -i distributed/test_distributed_fork
 | 
					 | 
				
			||||||
time python test/run_test.py --verbose -i distributed/test_c10d_common
 | 
					time python test/run_test.py --verbose -i distributed/test_c10d_common
 | 
				
			||||||
time python test/run_test.py --verbose -i distributed/test_c10d_gloo
 | 
					time python test/run_test.py --verbose -i distributed/test_c10d_gloo
 | 
				
			||||||
time python test/run_test.py --verbose -i distributed/test_c10d_nccl
 | 
					time python test/run_test.py --verbose -i distributed/test_c10d_nccl
 | 
				
			||||||
 | 
				
			|||||||
@ -21,8 +21,14 @@ from torch.testing._internal.common_distributed import (
 | 
				
			|||||||
    requires_nccl,
 | 
					    requires_nccl,
 | 
				
			||||||
    skip_if_lt_x_gpu,
 | 
					    skip_if_lt_x_gpu,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from torch.testing._internal.common_utils import run_tests
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
 | 
					    run_tests,
 | 
				
			||||||
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TEST_WITH_DEV_DBG_ASAN:
 | 
				
			||||||
 | 
					    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
 | 
				
			||||||
 | 
					    sys.exit(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def gpus_for_rank(world_size):
 | 
					def gpus_for_rank(world_size):
 | 
				
			||||||
    visible_devices = list(range(torch.cuda.device_count()))
 | 
					    visible_devices = list(range(torch.cuda.device_count()))
 | 
				
			||||||
@ -57,7 +63,7 @@ class TestDdpCommHook(nn.Module):
 | 
				
			|||||||
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 | 
					class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(DistributedDataParallelCommHookTest, self).setUp()
 | 
					        super(DistributedDataParallelCommHookTest, self).setUp()
 | 
				
			||||||
        self._fork_processes()
 | 
					        self._spawn_processes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def tearDown(self):
 | 
					    def tearDown(self):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
				
			|||||||
@ -37,7 +37,6 @@ from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 | 
				
			|||||||
from torch.distributed.rpc.backend_registry import BackendType
 | 
					from torch.distributed.rpc.backend_registry import BackendType
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    sandcastle_skip_if,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -406,19 +405,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
            self.assertEqual((100, 100), return_value.shape)
 | 
					            self.assertEqual((100, 100), return_value.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_dummy_compute_c10d(self):
 | 
					    def test_dummy_compute_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_dummy_compute_etcd(self):
 | 
					    def test_dummy_compute_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_dummy_compute_etcd_v2(self):
 | 
					    def test_dummy_compute_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute)
 | 
				
			||||||
@ -431,19 +430,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertIsNone(res.return_values[1])
 | 
					        self.assertIsNone(res.return_values[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_happy_function_c10d(self):
 | 
					    def test_run_happy_function_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_happy_function_etcd(self):
 | 
					    def test_run_happy_function_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_happy_function_etcd_v2(self):
 | 
					    def test_run_happy_function_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_happy_function)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_happy_function)
 | 
				
			||||||
@ -465,13 +464,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertIsNone(res.return_values[0])
 | 
					        self.assertIsNone(res.return_values[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_check_master_addr_port_override_etcd(self):
 | 
					    def test_check_master_addr_port_override_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.check_master_addr_port_override)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.check_master_addr_port_override)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_check_master_addr_port_override_etcd_v2(self):
 | 
					    def test_check_master_addr_port_override_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.check_master_addr_port_override)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.check_master_addr_port_override)
 | 
				
			||||||
@ -484,7 +483,7 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertFalse(res.is_failed())
 | 
					        self.assertFalse(res.is_failed())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_check_env_function_etcd(self):
 | 
					    def test_run_check_env_function_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_check_env_function)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_check_env_function)
 | 
				
			||||||
@ -497,19 +496,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual("foo", res.return_values[1])
 | 
					        self.assertEqual("foo", res.return_values[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_function_with_return_value_c10d(self):
 | 
					    def test_run_function_with_return_value_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_function_with_return_value)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_function_with_return_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_function_with_return_value_etcd(self):
 | 
					    def test_run_function_with_return_value_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_function_with_return_value)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_function_with_return_value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_function_with_return_value_etcd_v2(self):
 | 
					    def test_run_function_with_return_value_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_function_with_return_value)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_function_with_return_value)
 | 
				
			||||||
@ -520,19 +519,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        # _dist_sum internally checks that the sum computed is valid
 | 
					        # _dist_sum internally checks that the sum computed is valid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_simple_dist_sum_c10d(self):
 | 
					    def test_simple_dist_sum_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_simple_dist_sum_etcd(self):
 | 
					    def test_simple_dist_sum_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_simple_dist_sum_etcd_v2(self):
 | 
					    def test_simple_dist_sum_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum)
 | 
				
			||||||
@ -556,19 +555,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertSetEqual(set(range(4 + 4)), ranks)
 | 
					        self.assertSetEqual(set(range(4 + 4)), ranks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_homogeneous_c10d(self):
 | 
					    def test_run_distributed_sum_homogeneous_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_homogeneous)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_homogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_homogeneous_etcd(self):
 | 
					    def test_run_distributed_sum_homogeneous_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_homogeneous)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_homogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_homogeneous_etcd_v2(self):
 | 
					    def test_run_distributed_sum_homogeneous_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous)
 | 
				
			||||||
@ -596,19 +595,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertSetEqual(set(range(1 + 2 + 3)), ranks)
 | 
					        self.assertSetEqual(set(range(1 + 2 + 3)), ranks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_heterogeneous_c10d(self):
 | 
					    def test_run_distributed_sum_heterogeneous_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_heterogeneous_etcd(self):
 | 
					    def test_run_distributed_sum_heterogeneous_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_distributed_sum_heterogeneous_etcd_v2(self):
 | 
					    def test_run_distributed_sum_heterogeneous_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous)
 | 
				
			||||||
@ -636,19 +635,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
                self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp)
 | 
					                self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_sad_function_c10d(self):
 | 
					    def test_run_sad_function_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_sad_function_etcd(self):
 | 
					    def test_run_sad_function_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_sad_function_etcd_v2(self):
 | 
					    def test_run_sad_function_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function)
 | 
				
			||||||
@ -668,19 +667,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertTrue(agent._total_execution_time > 0)
 | 
					        self.assertTrue(agent._total_execution_time > 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_bipolar_function_c10d(self):
 | 
					    def test_run_bipolar_function_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_bipolar_function)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.run_bipolar_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_bipolar_function_etcd(self):
 | 
					    def test_run_bipolar_function_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_bipolar_function)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.run_bipolar_function)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_run_bipolar_function_etcd_v2(self):
 | 
					    def test_run_bipolar_function_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_bipolar_function)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_bipolar_function)
 | 
				
			||||||
@ -711,13 +710,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_correct_rank_assignment_heterogeneous_etcd(self):
 | 
					    def test_correct_rank_assignment_heterogeneous_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_correct_rank_assignment_heterogeneous_etcd_v2(self):
 | 
					    def test_correct_rank_assignment_heterogeneous_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous)
 | 
				
			||||||
@ -744,13 +743,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_correct_rank_assignment_homogeneous_etcd(self):
 | 
					    def test_correct_rank_assignment_homogeneous_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_correct_rank_assignment_homogeneous_etcd_v2(self):
 | 
					    def test_correct_rank_assignment_homogeneous_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous)
 | 
				
			||||||
@ -852,13 +851,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
            self.assertEqual(0, p.exitcode)
 | 
					            self.assertEqual(0, p.exitcode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_double_agent_fault_tolerance_etcd(self):
 | 
					    def test_double_agent_fault_tolerance_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_fault_tolerance)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_fault_tolerance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_double_agent_fault_tolerance_etcd_v2(self):
 | 
					    def test_double_agent_fault_tolerance_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance)
 | 
				
			||||||
@ -905,19 +904,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
                self.assertEqual(-signal.SIGKILL, p.exitcode)
 | 
					                self.assertEqual(-signal.SIGKILL, p.exitcode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_double_agent_elastic_c10d(self):
 | 
					    def test_double_agent_elastic_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.double_agent_elastic)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.double_agent_elastic)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_double_agent_elastic_etcd(self):
 | 
					    def test_double_agent_elastic_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_elastic)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_elastic)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_double_agent_elastic_etcd_v2(self):
 | 
					    def test_double_agent_elastic_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_elastic)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_elastic)
 | 
				
			||||||
@ -955,19 +954,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual([f"{msg} from worker"], list(master_retvals.values()))
 | 
					        self.assertEqual([f"{msg} from worker"], list(master_retvals.values()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_torch_rpc_c10d(self):
 | 
					    def test_torch_rpc_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_torch_rpc_etcd(self):
 | 
					    def test_torch_rpc_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_torch_rpc_etcd_v2(self):
 | 
					    def test_torch_rpc_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc)
 | 
				
			||||||
@ -993,13 +992,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
                self.assertEqual(rank, output)
 | 
					                self.assertEqual(rank, output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_workers_drift_success_etcd(self):
 | 
					    def test_workers_drift_success_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_success)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_success)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_workers_drift_success_etcd_v2(self):
 | 
					    def test_workers_drift_success_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_success)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_success)
 | 
				
			||||||
@ -1024,13 +1023,13 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
                self.assertEqual(rank, output)
 | 
					                self.assertEqual(rank, output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_workers_drift_fail_etcd(self):
 | 
					    def test_workers_drift_fail_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_workers_drift_fail_etcd_v2(self):
 | 
					    def test_workers_drift_fail_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_fail)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_fail)
 | 
				
			||||||
@ -1047,19 +1046,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        barrier_mock.assert_called_once()
 | 
					        barrier_mock.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_barrier_failed_c10d(self):
 | 
					    def test_barrier_failed_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_barrier_failed_etcd(self):
 | 
					    def test_barrier_failed_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_barrier_failed_etcd_v2(self):
 | 
					    def test_barrier_failed_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed)
 | 
				
			||||||
@ -1081,19 +1080,19 @@ class LocalElasticAgentTest(unittest.TestCase):
 | 
				
			|||||||
        pcontext_mock.close.assert_called_once()
 | 
					        pcontext_mock.close.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_shutdown_called_c10d(self):
 | 
					    def test_shutdown_called_c10d(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called)
 | 
					        self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_shutdown_called_etcd(self):
 | 
					    def test_shutdown_called_etcd(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called)
 | 
					        self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_shutdown_called_etcd_v2(self):
 | 
					    def test_shutdown_called_etcd_v2(self):
 | 
				
			||||||
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called)
 | 
					        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called)
 | 
				
			||||||
 | 
				
			|||||||
@ -35,8 +35,8 @@ from torch.distributed.elastic.multiprocessing.errors.error_handler import _writ
 | 
				
			|||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    NO_MULTIPROCESSING_SPAWN,
 | 
					    NO_MULTIPROCESSING_SPAWN,
 | 
				
			||||||
    TEST_WITH_ASAN,
 | 
					    TEST_WITH_ASAN,
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					    TEST_WITH_TSAN,
 | 
				
			||||||
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    IS_IN_CI,
 | 
					    IS_IN_CI,
 | 
				
			||||||
    IS_WINDOWS,
 | 
					    IS_WINDOWS,
 | 
				
			||||||
    IS_MACOS,
 | 
					    IS_MACOS,
 | 
				
			||||||
@ -223,15 +223,11 @@ def start_processes_zombie_test(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# tests incompatible with tsan or asan
 | 
					# tests incompatible with tsan or asan
 | 
				
			||||||
if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
					if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
 | 
				
			||||||
    class StartProcessesTest(unittest.TestCase):
 | 
					    class StartProcessesTest(unittest.TestCase):
 | 
				
			||||||
        def setUp(self):
 | 
					        def setUp(self):
 | 
				
			||||||
            self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
 | 
					            self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
 | 
				
			||||||
 | 
					            self._start_methods = ["spawn"]
 | 
				
			||||||
            if NO_MULTIPROCESSING_SPAWN:  # python 2.7 doesn't have spawn
 | 
					 | 
				
			||||||
                self._start_methods = ["fork"]
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self._start_methods = ["fork", "spawn"]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def tearDown(self):
 | 
					        def tearDown(self):
 | 
				
			||||||
            shutil.rmtree(self.test_dir)
 | 
					            shutil.rmtree(self.test_dir)
 | 
				
			||||||
@ -317,7 +313,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
                args={0: (1,)},
 | 
					                args={0: (1,)},
 | 
				
			||||||
                envs={0: {}},
 | 
					                envs={0: {}},
 | 
				
			||||||
                log_dir=self.log_dir(),
 | 
					                log_dir=self.log_dir(),
 | 
				
			||||||
                start_method="fork",
 | 
					                start_method="spawn",
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
 | 
					            self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
 | 
				
			||||||
@ -332,7 +328,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
                args={0: (1,)},
 | 
					                args={0: (1,)},
 | 
				
			||||||
                envs={0: {}},
 | 
					                envs={0: {}},
 | 
				
			||||||
                log_dir=self.log_dir(),
 | 
					                log_dir=self.log_dir(),
 | 
				
			||||||
                start_method="fork",
 | 
					                start_method="spawn",
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            pids = pc.pids()
 | 
					            pids = pc.pids()
 | 
				
			||||||
@ -387,7 +383,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
                    self.assertEqual({0: None, 1: None}, results.return_values)
 | 
					                    self.assertEqual({0: None, 1: None}, results.return_values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(
 | 
					        @sandcastle_skip_if(
 | 
				
			||||||
            TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
 | 
					            TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        def test_function_large_ret_val(self):
 | 
					        def test_function_large_ret_val(self):
 | 
				
			||||||
            # python multiprocessing.queue module uses pipes and actually PipedQueues
 | 
					            # python multiprocessing.queue module uses pipes and actually PipedQueues
 | 
				
			||||||
@ -549,7 +545,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
 | 
					# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
 | 
				
			||||||
if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
					if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
 | 
				
			||||||
    class StartProcessesListTest(StartProcessesTest):
 | 
					    class StartProcessesListTest(StartProcessesTest):
 | 
				
			||||||
        ########################################
 | 
					        ########################################
 | 
				
			||||||
        # start_processes as binary tests
 | 
					        # start_processes as binary tests
 | 
				
			||||||
@ -630,7 +626,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
                args={0: ("hello",), 1: ("world",)},
 | 
					                args={0: ("hello",), 1: ("world",)},
 | 
				
			||||||
                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
 | 
					                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
 | 
				
			||||||
                log_dir=self.log_dir(),
 | 
					                log_dir=self.log_dir(),
 | 
				
			||||||
                start_method="fork",
 | 
					                start_method="spawn",
 | 
				
			||||||
                redirects={0: Std.ERR, 1: Std.NONE},
 | 
					                redirects={0: Std.ERR, 1: Std.NONE},
 | 
				
			||||||
                tee={0: Std.OUT, 1: Std.ERR},
 | 
					                tee={0: Std.OUT, 1: Std.ERR},
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@ -647,7 +643,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
 | 
					# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
 | 
				
			||||||
if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):
 | 
					if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):
 | 
				
			||||||
    class StartProcessesNotCITest(StartProcessesTest):
 | 
					    class StartProcessesNotCITest(StartProcessesTest):
 | 
				
			||||||
        def test_wrap_bad(self):
 | 
					        def test_wrap_bad(self):
 | 
				
			||||||
            none = ""
 | 
					            none = ""
 | 
				
			||||||
@ -697,8 +693,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            failure = results.failures[0]
 | 
					            failure = results.failures[0]
 | 
				
			||||||
            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
 | 
					            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
 | 
				
			||||||
            if TEST_WITH_ASAN:
 | 
					            if TEST_WITH_ASAN or TEST_WITH_TSAN:
 | 
				
			||||||
                # ASAN exit code is 1.
 | 
					                # ASAN/TSAN exit code is 1.
 | 
				
			||||||
                self.assertEqual("<N/A>", failure.signal_name())
 | 
					                self.assertEqual("<N/A>", failure.signal_name())
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                self.assertEqual("SIGSEGV", failure.signal_name())
 | 
					                self.assertEqual("SIGSEGV", failure.signal_name())
 | 
				
			||||||
@ -714,7 +710,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS
 | 
				
			|||||||
                        args={0: ("hello",), 1: ("world",)},
 | 
					                        args={0: ("hello",), 1: ("world",)},
 | 
				
			||||||
                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
 | 
					                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
 | 
				
			||||||
                        log_dir=log_dir,
 | 
					                        log_dir=log_dir,
 | 
				
			||||||
                        start_method="fork",
 | 
					                        start_method="spawn",
 | 
				
			||||||
                        redirects={0: Std.ERR, 1: Std.NONE},
 | 
					                        redirects={0: Std.ERR, 1: Std.NONE},
 | 
				
			||||||
                        tee={0: Std.OUT, 1: Std.ERR},
 | 
					                        tee={0: Std.OUT, 1: Std.ERR},
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
				
			|||||||
@ -13,7 +13,6 @@ from torch.distributed.elastic.multiprocessing.errors import (
 | 
				
			|||||||
    record,
 | 
					    record,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
 | 
					from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
 | 
				
			||||||
from torch.testing._internal.common_utils import TEST_WITH_TSAN
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SentinelError(Exception):
 | 
					class SentinelError(Exception):
 | 
				
			||||||
@ -45,10 +44,6 @@ def read_resource_file(resource_file: str) -> str:
 | 
				
			|||||||
        return "".join(fp.readlines())
 | 
					        return "".join(fp.readlines())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print("test incompatible with tsan", file=sys.stderr)
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ApiTest(unittest.TestCase):
 | 
					class ApiTest(unittest.TestCase):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
 | 
					        self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
 | 
				
			||||||
 | 
				
			|||||||
@ -15,7 +15,6 @@ import torch.distributed.elastic.timer as timer
 | 
				
			|||||||
import torch.multiprocessing as torch_mp
 | 
					import torch.multiprocessing as torch_mp
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
    IS_WINDOWS,
 | 
					    IS_WINDOWS,
 | 
				
			||||||
    IS_MACOS,
 | 
					    IS_MACOS,
 | 
				
			||||||
@ -55,7 +54,7 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
        unittest. As of now this will SIGSEGV.
 | 
					        unittest. As of now this will SIGSEGV.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
 | 
					        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
 | 
				
			||||||
        def test_torch_mp_example(self):
 | 
					        def test_torch_mp_example(self):
 | 
				
			||||||
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
 | 
					            # in practice set the max_interval to a larger value (e.g. 60 seconds)
 | 
				
			||||||
            mp_queue = mp.get_context("spawn").Queue()
 | 
					            mp_queue = mp.get_context("spawn").Queue()
 | 
				
			||||||
@ -80,18 +79,14 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            server.stop()
 | 
					            server.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
 | 
					        @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
 | 
				
			||||||
        def test_example_start_method_spawn(self):
 | 
					        def test_example_start_method_spawn(self):
 | 
				
			||||||
            self._run_example_with(start_method="spawn")
 | 
					            self._run_example_with(start_method="spawn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible")
 | 
					        # @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
 | 
				
			||||||
        # def test_example_start_method_forkserver(self):
 | 
					        # def test_example_start_method_forkserver(self):
 | 
				
			||||||
        #     self._run_example_with(start_method="forkserver")
 | 
					        #     self._run_example_with(start_method="forkserver")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
 | 
					 | 
				
			||||||
        def test_example_start_method_fork(self):
 | 
					 | 
				
			||||||
            self._run_example_with(start_method="fork")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def _run_example_with(self, start_method):
 | 
					        def _run_example_with(self, start_method):
 | 
				
			||||||
            spawn_ctx = mp.get_context(start_method)
 | 
					            spawn_ctx = mp.get_context(start_method)
 | 
				
			||||||
            mp_queue = spawn_ctx.Queue()
 | 
					            mp_queue = spawn_ctx.Queue()
 | 
				
			||||||
 | 
				
			|||||||
@ -13,19 +13,28 @@ import torch.distributed.elastic.timer as timer
 | 
				
			|||||||
from torch.distributed.elastic.timer.api import TimerRequest
 | 
					from torch.distributed.elastic.timer.api import TimerRequest
 | 
				
			||||||
from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
 | 
					from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
    IS_WINDOWS,
 | 
					    IS_WINDOWS,
 | 
				
			||||||
    IS_MACOS,
 | 
					    IS_MACOS,
 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# timer is not supported on windows or macos
 | 
					# timer is not supported on windows or macos
 | 
				
			||||||
if not (IS_WINDOWS or IS_MACOS):
 | 
					if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
 | 
				
			||||||
 | 
					    # func2 should time out
 | 
				
			||||||
 | 
					    def func2(n, mp_queue):
 | 
				
			||||||
 | 
					        if mp_queue is not None:
 | 
				
			||||||
 | 
					            timer.configure(timer.LocalTimerClient(mp_queue))
 | 
				
			||||||
 | 
					        if n > 0:
 | 
				
			||||||
 | 
					            with timer.expires(after=0.1):
 | 
				
			||||||
 | 
					                func2(n - 1, None)
 | 
				
			||||||
 | 
					                time.sleep(0.2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class LocalTimerTest(unittest.TestCase):
 | 
					    class LocalTimerTest(unittest.TestCase):
 | 
				
			||||||
        def setUp(self):
 | 
					        def setUp(self):
 | 
				
			||||||
            self.mp_queue = mp.Queue()
 | 
					            self.ctx = mp.get_context('spawn')
 | 
				
			||||||
 | 
					            self.mp_queue = self.ctx.Queue()
 | 
				
			||||||
            self.max_interval = 0.01
 | 
					            self.max_interval = 0.01
 | 
				
			||||||
            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
 | 
					            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
 | 
				
			||||||
            self.server.start()
 | 
					            self.server.start()
 | 
				
			||||||
@ -62,7 +71,6 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
            with timer.expires(after=0.5):
 | 
					            with timer.expires(after=0.5):
 | 
				
			||||||
                time.sleep(0.1)
 | 
					                time.sleep(0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
 | 
					 | 
				
			||||||
        def test_get_timer_recursive(self):
 | 
					        def test_get_timer_recursive(self):
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
            If a function acquires a countdown timer with default scope,
 | 
					            If a function acquires a countdown timer with default scope,
 | 
				
			||||||
@ -82,14 +90,7 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            func(4)
 | 
					            func(4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # func2 should time out
 | 
					            p = self.ctx.Process(target=func2, args=(2, self.mp_queue))
 | 
				
			||||||
            def func2(n):
 | 
					 | 
				
			||||||
                if n > 0:
 | 
					 | 
				
			||||||
                    with timer.expires(after=0.1):
 | 
					 | 
				
			||||||
                        func2(n - 1)
 | 
					 | 
				
			||||||
                        time.sleep(0.2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            p = mp.Process(target=func2, args=(2,))
 | 
					 | 
				
			||||||
            p.start()
 | 
					            p.start()
 | 
				
			||||||
            p.join()
 | 
					            p.join()
 | 
				
			||||||
            self.assertEqual(-signal.SIGKILL, p.exitcode)
 | 
					            self.assertEqual(-signal.SIGKILL, p.exitcode)
 | 
				
			||||||
@ -102,7 +103,6 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
            with timer.expires(after=timeout):
 | 
					            with timer.expires(after=timeout):
 | 
				
			||||||
                time.sleep(duration)
 | 
					                time.sleep(duration)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
 | 
					 | 
				
			||||||
        def test_timer(self):
 | 
					        def test_timer(self):
 | 
				
			||||||
            timeout = 0.1
 | 
					            timeout = 0.1
 | 
				
			||||||
            duration = 1
 | 
					            duration = 1
 | 
				
			||||||
@ -124,7 +124,7 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# timer is not supported on windows or macos
 | 
					# timer is not supported on windows or macos
 | 
				
			||||||
if not (IS_WINDOWS or IS_MACOS):
 | 
					if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
 | 
				
			||||||
    class MultiprocessingRequestQueueTest(unittest.TestCase):
 | 
					    class MultiprocessingRequestQueueTest(unittest.TestCase):
 | 
				
			||||||
        def test_get(self):
 | 
					        def test_get(self):
 | 
				
			||||||
            mp_queue = mp.Queue()
 | 
					            mp_queue = mp.Queue()
 | 
				
			||||||
@ -183,7 +183,7 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# timer is not supported on windows or macos
 | 
					# timer is not supported on windows or macos
 | 
				
			||||||
if not (IS_WINDOWS or IS_MACOS):
 | 
					if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
 | 
				
			||||||
    class LocalTimerServerTest(unittest.TestCase):
 | 
					    class LocalTimerServerTest(unittest.TestCase):
 | 
				
			||||||
        def setUp(self):
 | 
					        def setUp(self):
 | 
				
			||||||
            self.mp_queue = mp.Queue()
 | 
					            self.mp_queue = mp.Queue()
 | 
				
			||||||
@ -193,7 +193,6 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
        def tearDown(self):
 | 
					        def tearDown(self):
 | 
				
			||||||
            self.server.stop()
 | 
					            self.server.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
 | 
					 | 
				
			||||||
        def test_watchdog_call_count(self):
 | 
					        def test_watchdog_call_count(self):
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
            checks that the watchdog function ran wait/interval +- 1 times
 | 
					            checks that the watchdog function ran wait/interval +- 1 times
 | 
				
			||||||
@ -226,7 +225,6 @@ if not (IS_WINDOWS or IS_MACOS):
 | 
				
			|||||||
        def _release_timer(self, pid, scope):
 | 
					        def _release_timer(self, pid, scope):
 | 
				
			||||||
            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)
 | 
					            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible")
 | 
					 | 
				
			||||||
        @mock.patch("os.kill")
 | 
					        @mock.patch("os.kill")
 | 
				
			||||||
        def test_expired_timers(self, mock_os_kill):
 | 
					        def test_expired_timers(self, mock_os_kill):
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,6 @@ from torch.distributed.launcher.api import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    sandcastle_skip_if,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -117,7 +116,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            rdzv_endpoint=endpoint,
 | 
					            rdzv_endpoint=endpoint,
 | 
				
			||||||
            monitor_interval=1,
 | 
					            monitor_interval=1,
 | 
				
			||||||
            rdzv_backend=rdzv_backend,
 | 
					            rdzv_backend=rdzv_backend,
 | 
				
			||||||
            start_method="fork",
 | 
					            start_method="spawn",
 | 
				
			||||||
            max_restarts=0,
 | 
					            max_restarts=0,
 | 
				
			||||||
            rdzv_configs=rdzv_configs,
 | 
					            rdzv_configs=rdzv_configs,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@ -128,7 +127,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_script_python(self):
 | 
					    def test_launch_script_python(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -145,7 +144,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        self.check_works_ran(world_size)
 | 
					        self.check_works_ran(world_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_script_python_local_rank_transfer(self):
 | 
					    def test_launch_script_python_local_rank_transfer(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -162,7 +161,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        self.check_works_ran(world_size)
 | 
					        self.check_works_ran(world_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_script_bash(self):
 | 
					    def test_launch_script_bash(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -177,7 +176,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        self.check_works_ran(world_size)
 | 
					        self.check_works_ran(world_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_function(self):
 | 
					    def test_launch_function(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -193,7 +192,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual(expected_res, actual_res)
 | 
					        self.assertEqual(expected_res, actual_res)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_dist_sum_with_static_rdzv(self):
 | 
					    def test_launch_dist_sum_with_static_rdzv(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -224,7 +223,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual(expected_res, actual_res)
 | 
					        self.assertEqual(expected_res, actual_res)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_elastic(self):
 | 
					    def test_launch_elastic(self):
 | 
				
			||||||
        nproc_per_node = 4
 | 
					        nproc_per_node = 4
 | 
				
			||||||
 | 
				
			|||||||
@ -15,7 +15,6 @@ import torch.distributed.launch as launch
 | 
				
			|||||||
from torch.distributed.elastic.utils import get_socket_with_port
 | 
					from torch.distributed.elastic.utils import get_socket_with_port
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    sandcastle_skip_if,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -36,7 +35,7 @@ class LaunchTest(unittest.TestCase):
 | 
				
			|||||||
        shutil.rmtree(self.test_dir)
 | 
					        shutil.rmtree(self.test_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_without_env(self):
 | 
					    def test_launch_without_env(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -49,7 +48,7 @@ class LaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--master_addr=localhost",
 | 
					            "--master_addr=localhost",
 | 
				
			||||||
            f"--master_port={master_port}",
 | 
					            f"--master_port={master_port}",
 | 
				
			||||||
            "--node_rank=0",
 | 
					            "--node_rank=0",
 | 
				
			||||||
@ -58,7 +57,7 @@ class LaunchTest(unittest.TestCase):
 | 
				
			|||||||
        launch.main(args)
 | 
					        launch.main(args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(
 | 
					    @sandcastle_skip_if(
 | 
				
			||||||
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan and dev/dbg asan"
 | 
					        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def test_launch_with_env(self):
 | 
					    def test_launch_with_env(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -71,7 +70,7 @@ class LaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--master_addr=localhost",
 | 
					            "--master_addr=localhost",
 | 
				
			||||||
            f"--master_port={master_port}",
 | 
					            f"--master_port={master_port}",
 | 
				
			||||||
            "--node_rank=0",
 | 
					            "--node_rank=0",
 | 
				
			||||||
 | 
				
			|||||||
@ -23,7 +23,6 @@ from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 | 
				
			|||||||
from torch.distributed.elastic.utils import get_socket_with_port
 | 
					from torch.distributed.elastic.utils import get_socket_with_port
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    sandcastle_skip_if,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -100,7 +99,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -123,7 +122,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--master_addr=localhost",
 | 
					            "--master_addr=localhost",
 | 
				
			||||||
            f"--master_port={master_port}",
 | 
					            f"--master_port={master_port}",
 | 
				
			||||||
            "--node_rank=0",
 | 
					            "--node_rank=0",
 | 
				
			||||||
@ -138,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_user_script_bash(self):
 | 
					    def test_launch_user_script_bash(self):
 | 
				
			||||||
        run_id = str(uuid.uuid4().int)
 | 
					        run_id = str(uuid.uuid4().int)
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -151,7 +150,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--no_python",
 | 
					            "--no_python",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -169,7 +168,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_user_script_default_nproc(self):
 | 
					    def test_launch_user_script_default_nproc(self):
 | 
				
			||||||
        run_id = str(uuid.uuid4().int)
 | 
					        run_id = str(uuid.uuid4().int)
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -180,7 +179,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--no_python",
 | 
					            "--no_python",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -198,7 +197,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_with_env_vars(self):
 | 
					    def test_launch_with_env_vars(self):
 | 
				
			||||||
        run_id = str(uuid.uuid4().int)
 | 
					        run_id = str(uuid.uuid4().int)
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -211,7 +210,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
 | 
					        os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
 | 
				
			||||||
        os.environ["PET_RDZV_ID"] = run_id
 | 
					        os.environ["PET_RDZV_ID"] = run_id
 | 
				
			||||||
        os.environ["PET_MONITOR_INTERVAL"] = "1"
 | 
					        os.environ["PET_MONITOR_INTERVAL"] = "1"
 | 
				
			||||||
        os.environ["PET_START_METHOD"] = "fork"
 | 
					        os.environ["PET_START_METHOD"] = "spawn"
 | 
				
			||||||
        os.environ["PET_NO_PYTHON"] = "1"
 | 
					        os.environ["PET_NO_PYTHON"] = "1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
 | 
					        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
 | 
				
			||||||
@ -241,7 +240,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--no_python",
 | 
					            "--no_python",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -256,27 +255,27 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_nproc_launch_auto_configurations(self):
 | 
					    def test_nproc_launch_auto_configurations(self):
 | 
				
			||||||
        self._test_nproc_launch_configuration("auto", os.cpu_count())
 | 
					        self._test_nproc_launch_configuration("auto", os.cpu_count())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_nproc_launch_number_configurations(self):
 | 
					    def test_nproc_launch_number_configurations(self):
 | 
				
			||||||
        self._test_nproc_launch_configuration("4", 4)
 | 
					        self._test_nproc_launch_configuration("4", 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_nproc_launch_unknown_configurations(self):
 | 
					    def test_nproc_launch_unknown_configurations(self):
 | 
				
			||||||
        with self.assertRaises(ValueError):
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
            self._test_nproc_launch_configuration("unknown", 4)
 | 
					            self._test_nproc_launch_configuration("unknown", 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    @patch("torch.cuda.is_available", return_value=True)
 | 
					    @patch("torch.cuda.is_available", return_value=True)
 | 
				
			||||||
    @patch("torch.cuda.device_count", return_value=3)
 | 
					    @patch("torch.cuda.device_count", return_value=3)
 | 
				
			||||||
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
 | 
					    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
 | 
				
			||||||
        self._test_nproc_launch_configuration("auto", 3)
 | 
					        self._test_nproc_launch_configuration("auto", 3)
 | 
				
			||||||
        self._test_nproc_launch_configuration("gpu", 3)
 | 
					        self._test_nproc_launch_configuration("gpu", 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_elastic(self):
 | 
					    def test_launch_elastic(self):
 | 
				
			||||||
        run_id = str(uuid.uuid4().int)
 | 
					        run_id = str(uuid.uuid4().int)
 | 
				
			||||||
        min_nodes = 1
 | 
					        min_nodes = 1
 | 
				
			||||||
@ -291,7 +290,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -304,7 +303,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @mock.patch("torch.distributed.elastic.events.record")
 | 
					    @mock.patch("torch.distributed.elastic.events.record")
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_elastic_worker_raise_exception(self, record_mock):
 | 
					    def test_launch_elastic_worker_raise_exception(self, record_mock):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Asserts that when the worker program fails and lancher raieses exception
 | 
					        Asserts that when the worker program fails and lancher raieses exception
 | 
				
			||||||
@ -323,7 +322,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--max_restarts=0",
 | 
					            "--max_restarts=0",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            "--fail",
 | 
					            "--fail",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -332,7 +331,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        record_mock.assert_called_once()
 | 
					        record_mock.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    @mock.patch(
 | 
					    @mock.patch(
 | 
				
			||||||
        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
 | 
					        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@ -354,7 +353,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--max_restarts=0",
 | 
					            "--max_restarts=0",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -364,7 +363,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            launch.main(args)
 | 
					            launch.main(args)
 | 
				
			||||||
        record_mock.assert_called_once()
 | 
					        record_mock.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_standalone(self):
 | 
					    def test_launch_standalone(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
        nproc_per_node = 4
 | 
					        nproc_per_node = 4
 | 
				
			||||||
@ -374,7 +373,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--standalone",
 | 
					            "--standalone",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -386,7 +385,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_run_path(self):
 | 
					    def test_launch_run_path(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
        nproc_per_node = 4
 | 
					        nproc_per_node = 4
 | 
				
			||||||
@ -396,7 +395,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -408,7 +407,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
					            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with tsan and dev/dbg asan")
 | 
					    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
 | 
				
			||||||
    def test_launch_elastic_multiple_agents(self):
 | 
					    def test_launch_elastic_multiple_agents(self):
 | 
				
			||||||
        run_id = str(uuid.uuid4().int)
 | 
					        run_id = str(uuid.uuid4().int)
 | 
				
			||||||
        min_nodes = 1
 | 
					        min_nodes = 1
 | 
				
			||||||
@ -423,7 +422,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
					            f"--rdzv_endpoint={self._etcd_endpoint}",
 | 
				
			||||||
            f"--rdzv_id={run_id}",
 | 
					            f"--rdzv_id={run_id}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
@ -462,7 +461,7 @@ class ElasticLaunchTest(unittest.TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            path("bin/test_script.py"),
 | 
					            path("bin/test_script.py"),
 | 
				
			||||||
            f"--touch_file_dir={self.test_dir}",
 | 
					            f"--touch_file_dir={self.test_dir}",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
				
			|||||||
@ -28,9 +28,13 @@ from torch.testing._internal.common_utils import (
 | 
				
			|||||||
    TestCase,
 | 
					    TestCase,
 | 
				
			||||||
    load_tests,
 | 
					    load_tests,
 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TEST_WITH_DEV_DBG_ASAN:
 | 
				
			||||||
 | 
					    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
 | 
				
			||||||
 | 
					    sys.exit(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# load_tests from common_utils is used to automatically filter tests for
 | 
					# load_tests from common_utils is used to automatically filter tests for
 | 
				
			||||||
# sharding on sandcastle. This line silences flake warnings
 | 
					# sharding on sandcastle. This line silences flake warnings
 | 
				
			||||||
load_tests = load_tests
 | 
					load_tests = load_tests
 | 
				
			||||||
@ -438,37 +442,31 @@ class AbstractDistributedDataParallelTest(object):
 | 
				
			|||||||
        return fut.then(fut_then)
 | 
					        return fut.then(fut_then)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TSAN is not fork-safe since we're forking in a multi-threaded environment
 | 
					class DistributedDataParallelTest(
 | 
				
			||||||
if not TEST_WITH_TSAN:
 | 
					    AbstractDistributedDataParallelTest, MultiProcessTestCase
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        super(DistributedDataParallelTest, self).setUp()
 | 
				
			||||||
 | 
					        self._spawn_processes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class DistributedDataParallelTest(
 | 
					    def test_invalid_powerSGD_state(self):
 | 
				
			||||||
        AbstractDistributedDataParallelTest, MultiProcessTestCase
 | 
					        for start_powerSGD_iter, use_error_feedback, warm_start in product(
 | 
				
			||||||
    ):
 | 
					            [0, 1], [True, False], [True, False]
 | 
				
			||||||
        def setUp(self):
 | 
					        ):
 | 
				
			||||||
            super(DistributedDataParallelTest, self).setUp()
 | 
					            if not use_error_feedback and not warm_start:
 | 
				
			||||||
            if sys.platform == "win32":
 | 
					                continue
 | 
				
			||||||
                self._spawn_processes()
 | 
					            with self.assertRaisesRegex(
 | 
				
			||||||
            else:
 | 
					                ValueError,
 | 
				
			||||||
                self._fork_processes()
 | 
					                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
 | 
				
			||||||
 | 
					                "because PowerSGD can only be applied after the first two iterations in DDP.",
 | 
				
			||||||
        def test_invalid_powerSGD_state(self):
 | 
					 | 
				
			||||||
            for start_powerSGD_iter, use_error_feedback, warm_start in product(
 | 
					 | 
				
			||||||
                [0, 1], [True, False], [True, False]
 | 
					 | 
				
			||||||
            ):
 | 
					            ):
 | 
				
			||||||
                if not use_error_feedback and not warm_start:
 | 
					                state = powerSGD.PowerSGDState(
 | 
				
			||||||
                    continue
 | 
					                    process_group=None,
 | 
				
			||||||
                with self.assertRaisesRegex(
 | 
					                    matrix_approximation_rank=1,
 | 
				
			||||||
                    ValueError,
 | 
					                    start_powerSGD_iter=start_powerSGD_iter,
 | 
				
			||||||
                    "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
 | 
					                    use_error_feedback=use_error_feedback,
 | 
				
			||||||
                    "because PowerSGD can only be applied after the first two iterations in DDP.",
 | 
					                    warm_start=warm_start,
 | 
				
			||||||
                ):
 | 
					                )
 | 
				
			||||||
                    state = powerSGD.PowerSGDState(
 | 
					 | 
				
			||||||
                        process_group=None,
 | 
					 | 
				
			||||||
                        matrix_approximation_rank=1,
 | 
					 | 
				
			||||||
                        start_powerSGD_iter=start_powerSGD_iter,
 | 
					 | 
				
			||||||
                        use_error_feedback=use_error_feedback,
 | 
					 | 
				
			||||||
                        warm_start=warm_start,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ComputeBucketAssignmentTest(TestCase):
 | 
					class ComputeBucketAssignmentTest(TestCase):
 | 
				
			||||||
@ -656,49 +654,42 @@ class AbstractCommTest(object):
 | 
				
			|||||||
            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
 | 
					            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
 | 
				
			||||||
            self.assertEqual(len(set(obj_list)), 1)
 | 
					            self.assertEqual(len(set(obj_list)), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CommTest(AbstractCommTest, MultiProcessTestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        super(CommTest, self).setUp()
 | 
				
			||||||
 | 
					        self._spawn_processes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TSAN is not fork-safe since we're forking in a multi-threaded environment
 | 
					    def tearDown(self):
 | 
				
			||||||
if not TEST_WITH_TSAN:
 | 
					        super(CommTest, self).tearDown()
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            os.remove(self.file_name)
 | 
				
			||||||
 | 
					        except OSError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class CommTest(AbstractCommTest, MultiProcessTestCase):
 | 
					    def test_distributed_debug_mode(self):
 | 
				
			||||||
        def setUp(self):
 | 
					        # Default should be off
 | 
				
			||||||
            super(CommTest, self).setUp()
 | 
					        default_debug_mode = dist._get_debug_mode()
 | 
				
			||||||
            if sys.platform == "win32":
 | 
					        self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF)
 | 
				
			||||||
                self._spawn_processes()
 | 
					        mapping = {
 | 
				
			||||||
            else:
 | 
					            "OFF": dist._DistributedDebugLevel.OFF,
 | 
				
			||||||
                self._fork_processes()
 | 
					            "INFO": dist._DistributedDebugLevel.INFO,
 | 
				
			||||||
 | 
					            "DETAIL": dist._DistributedDebugLevel.DETAIL,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        invalid_debug_modes = ["foo", 0, 1, -1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def tearDown(self):
 | 
					        for mode in mapping.keys():
 | 
				
			||||||
            super(CommTest, self).tearDown()
 | 
					            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
 | 
				
			||||||
            try:
 | 
					            set_debug_mode = dist._get_debug_mode()
 | 
				
			||||||
                os.remove(self.file_name)
 | 
					            self.assertEqual(
 | 
				
			||||||
            except OSError:
 | 
					                set_debug_mode,
 | 
				
			||||||
                pass
 | 
					                mapping[mode],
 | 
				
			||||||
 | 
					                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_distributed_debug_mode(self):
 | 
					        for mode in invalid_debug_modes:
 | 
				
			||||||
            # Default should be off
 | 
					            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
 | 
				
			||||||
            default_debug_mode = dist._get_debug_mode()
 | 
					            with self.assertRaisesRegex(RuntimeError, "to be one of"):
 | 
				
			||||||
            self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF)
 | 
					                dist._get_debug_mode()
 | 
				
			||||||
            mapping = {
 | 
					 | 
				
			||||||
                "OFF": dist._DistributedDebugLevel.OFF,
 | 
					 | 
				
			||||||
                "INFO": dist._DistributedDebugLevel.INFO,
 | 
					 | 
				
			||||||
                "DETAIL": dist._DistributedDebugLevel.DETAIL,
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            invalid_debug_modes = ["foo", 0, 1, -1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for mode in mapping.keys():
 | 
					 | 
				
			||||||
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
 | 
					 | 
				
			||||||
                set_debug_mode = dist._get_debug_mode()
 | 
					 | 
				
			||||||
                self.assertEqual(
 | 
					 | 
				
			||||||
                    set_debug_mode,
 | 
					 | 
				
			||||||
                    mapping[mode],
 | 
					 | 
				
			||||||
                    f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for mode in invalid_debug_modes:
 | 
					 | 
				
			||||||
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
 | 
					 | 
				
			||||||
                with self.assertRaisesRegex(RuntimeError, "to be one of"):
 | 
					 | 
				
			||||||
                    dist._get_debug_mode()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
@ -43,17 +43,9 @@ from torch.testing._internal.common_utils import (
 | 
				
			|||||||
    TestCase,
 | 
					    TestCase,
 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
    retry_on_connect_failures,
 | 
					    retry_on_connect_failures,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip,
 | 
					    sandcastle_skip,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print(
 | 
					 | 
				
			||||||
        "Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment",
 | 
					 | 
				
			||||||
        file=sys.stderr,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def simple_reduce_tests(rank, world_size):
 | 
					def simple_reduce_tests(rank, world_size):
 | 
				
			||||||
    tests = [
 | 
					    tests = [
 | 
				
			||||||
@ -218,12 +210,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(ProcessGroupGlooTest, self).setUp()
 | 
					        super(ProcessGroupGlooTest, self).setUp()
 | 
				
			||||||
 | 
					        self._spawn_processes()
 | 
				
			||||||
        # For Windows platform, Python does not support fork, change it to spawn here.
 | 
					 | 
				
			||||||
        if sys.platform == "win32":
 | 
					 | 
				
			||||||
            self._spawn_processes()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self._fork_processes()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def opts(self, threads=2):
 | 
					    def opts(self, threads=2):
 | 
				
			||||||
        opts = c10d.ProcessGroupGloo._Options()
 | 
					        opts = c10d.ProcessGroupGloo._Options()
 | 
				
			||||||
@ -1425,10 +1412,7 @@ class DistributedDataParallelTest(
 | 
				
			|||||||
):
 | 
					):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(DistributedDataParallelTest, self).setUp()
 | 
					        super(DistributedDataParallelTest, self).setUp()
 | 
				
			||||||
        if sys.platform == "win32":
 | 
					        self._spawn_processes()
 | 
				
			||||||
            self._spawn_processes()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self._fork_processes()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _test_gloo_backend(
 | 
					    def _test_gloo_backend(
 | 
				
			||||||
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
 | 
					        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
 | 
				
			||||||
@ -2197,10 +2181,7 @@ class ReducerTest(TestCase):
 | 
				
			|||||||
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
 | 
					class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(CommTest, self).setUp()
 | 
					        super(CommTest, self).setUp()
 | 
				
			||||||
        if sys.platform == "win32":
 | 
					        self._spawn_processes()
 | 
				
			||||||
            self._spawn_processes()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self._fork_processes()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def tearDown(self):
 | 
					    def tearDown(self):
 | 
				
			||||||
        super(CommTest, self).tearDown()
 | 
					        super(CommTest, self).tearDown()
 | 
				
			||||||
 | 
				
			|||||||
@ -45,7 +45,6 @@ from torch.testing._internal.common_utils import (
 | 
				
			|||||||
    retry_on_connect_failures,
 | 
					    retry_on_connect_failures,
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_ROCM,
 | 
					    TEST_WITH_ROCM,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    sandcastle_skip,
 | 
					    sandcastle_skip,
 | 
				
			||||||
    sandcastle_skip_if,
 | 
					    sandcastle_skip_if,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -57,13 +56,6 @@ if not IS_WINDOWS:
 | 
				
			|||||||
    from torch.distributed.optim.functional_adam import _FunctionalAdam
 | 
					    from torch.distributed.optim.functional_adam import _FunctionalAdam
 | 
				
			||||||
    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
 | 
					    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print(
 | 
					 | 
				
			||||||
        "Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment",
 | 
					 | 
				
			||||||
        file=sys.stderr,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if TEST_WITH_DEV_DBG_ASAN:
 | 
					if TEST_WITH_DEV_DBG_ASAN:
 | 
				
			||||||
    print(
 | 
					    print(
 | 
				
			||||||
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
 | 
					        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ from test_c10d_spawn import _torch_dist_nn_available
 | 
				
			|||||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
 | 
					from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
 | 
				
			||||||
from torch.testing._internal.common_distributed import requires_gloo, \
 | 
					from torch.testing._internal.common_distributed import requires_gloo, \
 | 
				
			||||||
    create_device, MultiProcessTestCase, skip_if_lt_x_gpu
 | 
					    create_device, MultiProcessTestCase, skip_if_lt_x_gpu
 | 
				
			||||||
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_TSAN, TEST_WITH_DEV_DBG_ASAN
 | 
					from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
 | 
					# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
 | 
				
			||||||
if sys.version_info < (3, 9):
 | 
					if sys.version_info < (3, 9):
 | 
				
			||||||
@ -76,102 +76,100 @@ if sys.version_info < (3, 9):
 | 
				
			|||||||
                self.world_size)
 | 
					                self.world_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TSAN is not fork-safe since we're forking in a multi-threaded environment
 | 
					class DistributedDataParallelSingleProcessTest(TestCase):
 | 
				
			||||||
if not TEST_WITH_TSAN:
 | 
					    def setUp(self):
 | 
				
			||||||
    class DistributedDataParallelSingleProcessTest(TestCase):
 | 
					        self.rank = 0
 | 
				
			||||||
        def setUp(self):
 | 
					        self.world_size = 1
 | 
				
			||||||
            self.rank = 0
 | 
					        self.file = tempfile.NamedTemporaryFile(delete=False)  # noqa: P201
 | 
				
			||||||
            self.world_size = 1
 | 
					 | 
				
			||||||
            self.file = tempfile.NamedTemporaryFile(delete=False)  # noqa: P201
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def tearDown(self):
 | 
					    def tearDown(self):
 | 
				
			||||||
            try:
 | 
					        try:
 | 
				
			||||||
                os.remove(self.file.name)
 | 
					            os.remove(self.file.name)
 | 
				
			||||||
            except OSError:
 | 
					        except OSError:
 | 
				
			||||||
                pass
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def _test_base(self, net, inp, check_allclose=True):
 | 
					    def _test_base(self, net, inp, check_allclose=True):
 | 
				
			||||||
            store = c10d.FileStore(self.file.name, self.world_size)
 | 
					        store = c10d.FileStore(self.file.name, self.world_size)
 | 
				
			||||||
            process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
 | 
					        process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
 | 
				
			||||||
            if inp[0].is_cuda:
 | 
					        if inp[0].is_cuda:
 | 
				
			||||||
                device_ids = [torch.cuda.current_device()]
 | 
					            device_ids = [torch.cuda.current_device()]
 | 
				
			||||||
            else:
 | 
					        else:
 | 
				
			||||||
                device_ids = None
 | 
					            device_ids = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ddp = nn.parallel.DistributedDataParallel(
 | 
					        ddp = nn.parallel.DistributedDataParallel(
 | 
				
			||||||
                copy.deepcopy(net),
 | 
					            copy.deepcopy(net),
 | 
				
			||||||
                device_ids=device_ids,
 | 
					            device_ids=device_ids,
 | 
				
			||||||
                process_group=process_group
 | 
					            process_group=process_group
 | 
				
			||||||
            )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            net_opt = torch.optim.Adam(net.parameters(), lr=0.001)
 | 
					        net_opt = torch.optim.Adam(net.parameters(), lr=0.001)
 | 
				
			||||||
            ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001)
 | 
					        ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, j in zip(ddp.parameters(), net.parameters()):
 | 
				
			||||||
 | 
					            self.assertTrue(i.allclose(j))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for _ in range(10):
 | 
				
			||||||
 | 
					            net_out = net(*inp)
 | 
				
			||||||
 | 
					            ddp_out = ddp(*inp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            net_out.sum().backward()
 | 
				
			||||||
 | 
					            ddp_out.sum().backward()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            net_opt.step()
 | 
				
			||||||
 | 
					            ddp_opt.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if check_allclose:
 | 
				
			||||||
            for i, j in zip(ddp.parameters(), net.parameters()):
 | 
					            for i, j in zip(ddp.parameters(), net.parameters()):
 | 
				
			||||||
                self.assertTrue(i.allclose(j))
 | 
					                self.assertTrue(i.allclose(j))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for _ in range(10):
 | 
					    @requires_gloo()
 | 
				
			||||||
                net_out = net(*inp)
 | 
					    def test_cpu(self):
 | 
				
			||||||
                ddp_out = ddp(*inp)
 | 
					        self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                net_out.sum().backward()
 | 
					    @requires_gloo()
 | 
				
			||||||
                ddp_out.sum().backward()
 | 
					    @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
 | 
				
			||||||
 | 
					    def test_cuda(self):
 | 
				
			||||||
 | 
					        self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                net_opt.step()
 | 
					    @requires_gloo()
 | 
				
			||||||
                ddp_opt.step()
 | 
					    @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
 | 
				
			||||||
 | 
					    def test_rnn(self):
 | 
				
			||||||
 | 
					        # This test is inspired by the bug reported in
 | 
				
			||||||
 | 
					        # https://github.com/pytorch/pytorch/issues/36268
 | 
				
			||||||
 | 
					        BATCH_SIZE = 12  # Divisible by 2, 3, 4
 | 
				
			||||||
 | 
					        INPUT_DIM = 256
 | 
				
			||||||
 | 
					        OUTPUT_DIM = 256
 | 
				
			||||||
 | 
					        HIDDEN_DIM = 256
 | 
				
			||||||
 | 
					        N_LAYERS = 3
 | 
				
			||||||
 | 
					        SEQ_LEN = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if check_allclose:
 | 
					        class Net(nn.Module):
 | 
				
			||||||
                for i, j in zip(ddp.parameters(), net.parameters()):
 | 
					            def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
 | 
				
			||||||
                    self.assertTrue(i.allclose(j))
 | 
					                super(Net, self).__init__()
 | 
				
			||||||
 | 
					                self.input_dim = input_dim
 | 
				
			||||||
 | 
					                self.hidden_dim = hidden_dim
 | 
				
			||||||
 | 
					                self.output_dim = output_dim
 | 
				
			||||||
 | 
					                self.hidden_layers = hidden_layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @requires_gloo()
 | 
					                self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True)
 | 
				
			||||||
        def test_cpu(self):
 | 
					                self.h2o = nn.Linear(hidden_dim, output_dim)
 | 
				
			||||||
            self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @requires_gloo()
 | 
					            def forward(self, x, y):
 | 
				
			||||||
        @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
 | 
					                self.lstm.flatten_parameters()
 | 
				
			||||||
        def test_cuda(self):
 | 
					                h_t, _ = self.lstm(x)
 | 
				
			||||||
            self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)])
 | 
					                output = self.h2o(h_t)
 | 
				
			||||||
 | 
					                loss = nn.functional.mse_loss(output, y)
 | 
				
			||||||
 | 
					                return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @requires_gloo()
 | 
					        net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
 | 
				
			||||||
        @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed")
 | 
					        inp = [
 | 
				
			||||||
        def test_rnn(self):
 | 
					            torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
 | 
				
			||||||
            # This test is inspired by the bug reported in
 | 
					            torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0)
 | 
				
			||||||
            # https://github.com/pytorch/pytorch/issues/36268
 | 
					        ]
 | 
				
			||||||
            BATCH_SIZE = 12  # Divisible by 2, 3, 4
 | 
					 | 
				
			||||||
            INPUT_DIM = 256
 | 
					 | 
				
			||||||
            OUTPUT_DIM = 256
 | 
					 | 
				
			||||||
            HIDDEN_DIM = 256
 | 
					 | 
				
			||||||
            N_LAYERS = 3
 | 
					 | 
				
			||||||
            SEQ_LEN = 100
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            class Net(nn.Module):
 | 
					        # Not checking result allclose as the parameter inconsistency exist
 | 
				
			||||||
                def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
 | 
					        # prior to this change. See #37079
 | 
				
			||||||
                    super(Net, self).__init__()
 | 
					        self._test_base(net, inp, check_allclose=False)
 | 
				
			||||||
                    self.input_dim = input_dim
 | 
					 | 
				
			||||||
                    self.hidden_dim = hidden_dim
 | 
					 | 
				
			||||||
                    self.output_dim = output_dim
 | 
					 | 
				
			||||||
                    self.hidden_layers = hidden_layers
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True)
 | 
					 | 
				
			||||||
                    self.h2o = nn.Linear(hidden_dim, output_dim)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                def forward(self, x, y):
 | 
					 | 
				
			||||||
                    self.lstm.flatten_parameters()
 | 
					 | 
				
			||||||
                    h_t, _ = self.lstm(x)
 | 
					 | 
				
			||||||
                    output = self.h2o(h_t)
 | 
					 | 
				
			||||||
                    loss = nn.functional.mse_loss(output, y)
 | 
					 | 
				
			||||||
                    return loss
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
 | 
					 | 
				
			||||||
            inp = [
 | 
					 | 
				
			||||||
                torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
 | 
					 | 
				
			||||||
                torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0)
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Not checking result allclose as the parameter inconsistency exist
 | 
					 | 
				
			||||||
            # prior to this change. See #37079
 | 
					 | 
				
			||||||
            self._test_base(net, inp, check_allclose=False)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Skip dev-asan as torch + multiprocessing spawn have known issues
 | 
					# Skip dev-asan as torch + multiprocessing spawn have known issues
 | 
				
			||||||
 | 
				
			|||||||
@ -1,113 +0,0 @@
 | 
				
			|||||||
import os
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import tempfile
 | 
					 | 
				
			||||||
from functools import wraps
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
import torch.cuda
 | 
					 | 
				
			||||||
import torch.distributed as dist
 | 
					 | 
				
			||||||
from torch.testing._internal.common_utils import TEST_WITH_TSAN
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if not dist.is_available():
 | 
					 | 
				
			||||||
    print("Distributed not available, skipping tests", file=sys.stderr)
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from torch.testing._internal.common_utils import TestCase, find_free_port, run_tests
 | 
					 | 
				
			||||||
from torch.distributed.distributed_c10d import _get_default_group
 | 
					 | 
				
			||||||
from torch.testing._internal.distributed.distributed_test import (
 | 
					 | 
				
			||||||
    DistributedTest, TestDistBackend
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
torch.backends.cuda.matmul.allow_tf32 = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
CPP_EXTENSIONS_WARNING = """
 | 
					 | 
				
			||||||
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,
 | 
					 | 
				
			||||||
but it could not be found. Install ninja with `pip install ninja`
 | 
					 | 
				
			||||||
or `conda install ninja`.
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
BACKEND = os.environ["BACKEND"]
 | 
					 | 
				
			||||||
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def skip_if_no_ninja(func):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @wraps(func)
 | 
					 | 
				
			||||||
    def wrapper(*args, **kwargs):
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            import torch.utils.cpp_extension
 | 
					 | 
				
			||||||
            torch.utils.cpp_extension.verify_ninja_availability()
 | 
					 | 
				
			||||||
        except RuntimeError:
 | 
					 | 
				
			||||||
            print(CPP_EXTENSIONS_WARNING)
 | 
					 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return func(*args, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return wrapper
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if BACKEND == "gloo" or BACKEND == "nccl":
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    class TestDistBackendWithFork(TestDistBackend, DistributedTest._DistTestBase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def setUp(self):
 | 
					 | 
				
			||||||
            super().setUp()
 | 
					 | 
				
			||||||
            self._fork_processes()
 | 
					 | 
				
			||||||
            torch.backends.cudnn.flags(allow_tf32=False).__enter__()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
elif BACKEND == "mpi":
 | 
					 | 
				
			||||||
    WORLD_SIZE = os.environ["WORLD_SIZE"]
 | 
					 | 
				
			||||||
    dist.init_process_group(init_method=INIT_METHOD, backend="mpi")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    class TestMPIWithFork(TestCase, DistributedTest._DistTestBase):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
elif BACKEND == "test":
 | 
					 | 
				
			||||||
    class TestBackendDynamicLoad(TestCase):
 | 
					 | 
				
			||||||
        def setUp(self):
 | 
					 | 
				
			||||||
            super(TestBackendDynamicLoad, self).setUp()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def _load_test_backend(self):
 | 
					 | 
				
			||||||
            temp_dir = tempfile.mkdtemp()
 | 
					 | 
				
			||||||
            src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__)))
 | 
					 | 
				
			||||||
            extension = torch.utils.cpp_extension.load(
 | 
					 | 
				
			||||||
                name="torch_test",
 | 
					 | 
				
			||||||
                sources=[src],
 | 
					 | 
				
			||||||
                build_directory=temp_dir
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        @skip_if_no_ninja
 | 
					 | 
				
			||||||
        def test_backend_apis(self):
 | 
					 | 
				
			||||||
            self._load_test_backend()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            os.environ['WORLD_SIZE'] = '1'
 | 
					 | 
				
			||||||
            os.environ['MASTER_ADDR'] = '127.0.0.1'
 | 
					 | 
				
			||||||
            os.environ['MASTER_PORT'] = str(find_free_port())
 | 
					 | 
				
			||||||
            os.environ['RANK'] = '0'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            dist.init_process_group(backend='test', init_method='env://', world_size=1, rank=0)
 | 
					 | 
				
			||||||
            self.assertEqual(dist.get_rank(), 0)
 | 
					 | 
				
			||||||
            self.assertEqual(dist.get_world_size(), 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            process_group = _get_default_group()
 | 
					 | 
				
			||||||
            work = process_group.allreduce([torch.rand(1), torch.rand(1)])
 | 
					 | 
				
			||||||
            self.assertTrue(work.wait())
 | 
					 | 
				
			||||||
            self.assertTrue(work.is_completed())
 | 
					 | 
				
			||||||
            self.assertTrue(work.is_success())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            work = process_group.broadcast([torch.rand(1)])
 | 
					 | 
				
			||||||
            self.assertTrue(work.wait())
 | 
					 | 
				
			||||||
            self.assertTrue(work.is_completed())
 | 
					 | 
				
			||||||
            self.assertTrue(work.is_success())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            dist.destroy_process_group()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    assert (
 | 
					 | 
				
			||||||
        not torch.cuda._initialized
 | 
					 | 
				
			||||||
    ), "test_distributed must not have initialized CUDA context on main process"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    run_tests()
 | 
					 | 
				
			||||||
@ -6,7 +6,7 @@ import time
 | 
				
			|||||||
from typing import List
 | 
					from typing import List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from torch.testing._internal.common_distributed import requires_nccl, create_tcp_store
 | 
					from torch.testing._internal.common_distributed import requires_nccl, create_tcp_store
 | 
				
			||||||
from torch.testing._internal.common_utils import load_tests, TEST_WITH_TSAN, run_tests, sandcastle_skip_if
 | 
					from torch.testing._internal.common_utils import load_tests, run_tests, sandcastle_skip_if
 | 
				
			||||||
from torch.testing._internal.jit_utils import JitTestCase
 | 
					from torch.testing._internal.jit_utils import JitTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# load_tests from common_utils is used to automatically filter tests for
 | 
					# load_tests from common_utils is used to automatically filter tests for
 | 
				
			||||||
@ -29,10 +29,6 @@ def unique_process_group_name(prefix):
 | 
				
			|||||||
    now = int(time.time() * 1000)
 | 
					    now = int(time.time() * 1000)
 | 
				
			||||||
    return "%s_%d" % (prefix, now)
 | 
					    return "%s_%d" % (prefix, now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ProcessGroupNCCLJitTest(JitTestCase):
 | 
					class ProcessGroupNCCLJitTest(JitTestCase):
 | 
				
			||||||
    MAIN_PROCESS_RANK = 0
 | 
					    MAIN_PROCESS_RANK = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,6 @@ if not dist.is_available():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    TestCase,
 | 
					    TestCase,
 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -25,10 +24,6 @@ if TEST_WITH_DEV_DBG_ASAN:
 | 
				
			|||||||
    print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr)
 | 
					    print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr)
 | 
				
			||||||
    sys.exit(0)
 | 
					    sys.exit(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TEST_WITH_TSAN:
 | 
					 | 
				
			||||||
    print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr)
 | 
					 | 
				
			||||||
    sys.exit(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestDistributedLaunch(TestCase):
 | 
					class TestDistributedLaunch(TestCase):
 | 
				
			||||||
    def test_launch_user_script(self):
 | 
					    def test_launch_user_script(self):
 | 
				
			||||||
        nnodes = 1
 | 
					        nnodes = 1
 | 
				
			||||||
@ -41,7 +36,7 @@ class TestDistributedLaunch(TestCase):
 | 
				
			|||||||
            f"--nnodes={nnodes}",
 | 
					            f"--nnodes={nnodes}",
 | 
				
			||||||
            f"--nproc_per_node={nproc_per_node}",
 | 
					            f"--nproc_per_node={nproc_per_node}",
 | 
				
			||||||
            "--monitor_interval=1",
 | 
					            "--monitor_interval=1",
 | 
				
			||||||
            "--start_method=fork",
 | 
					            "--start_method=spawn",
 | 
				
			||||||
            "--master_addr=localhost",
 | 
					            "--master_addr=localhost",
 | 
				
			||||||
            f"--master_port={master_port}",
 | 
					            f"--master_port={master_port}",
 | 
				
			||||||
            "--node_rank=0",
 | 
					            "--node_rank=0",
 | 
				
			||||||
 | 
				
			|||||||
@ -20,7 +20,6 @@ from torch.testing._internal.common_distributed import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from torch.testing._internal.common_utils import (
 | 
					from torch.testing._internal.common_utils import (
 | 
				
			||||||
    run_tests,
 | 
					    run_tests,
 | 
				
			||||||
    TEST_WITH_TSAN,
 | 
					 | 
				
			||||||
    TEST_WITH_DEV_DBG_ASAN,
 | 
					    TEST_WITH_DEV_DBG_ASAN,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -28,11 +27,7 @@ from torch.testing._internal.common_utils import (
 | 
				
			|||||||
class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
 | 
					class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(AbstractProcessGroupWrapperTest, self).setUp()
 | 
					        super(AbstractProcessGroupWrapperTest, self).setUp()
 | 
				
			||||||
        # For Windows platform, Python does not support fork, change it to spawn here.
 | 
					        self._spawn_processes()
 | 
				
			||||||
        if sys.platform == "win32":
 | 
					 | 
				
			||||||
            self._spawn_processes()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self._fork_processes()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _validate_error(self, exception, op_type, rank, tensor):
 | 
					    def _validate_error(self, exception, op_type, rank, tensor):
 | 
				
			||||||
        err = str(exception)
 | 
					        err = str(exception)
 | 
				
			||||||
@ -291,91 +286,89 @@ if not TEST_WITH_DEV_DBG_ASAN:
 | 
				
			|||||||
            self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
					            self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TSAN is not fork-safe since we're forking in a multi-threaded environment
 | 
					@requires_gloo()
 | 
				
			||||||
if not TEST_WITH_TSAN:
 | 
					class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
 | 
				
			||||||
    @requires_gloo()
 | 
					    def setUp(self):
 | 
				
			||||||
    class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
 | 
					        super(ProcessGroupGlooWrapperTest, self).setUp()
 | 
				
			||||||
        def setUp(self):
 | 
					 | 
				
			||||||
            super(ProcessGroupGlooWrapperTest, self).setUp()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def opts(self, threads=2, timeout=10.0):
 | 
					    def opts(self, threads=2, timeout=10.0):
 | 
				
			||||||
            opts = c10d.ProcessGroupGloo._Options()
 | 
					        opts = c10d.ProcessGroupGloo._Options()
 | 
				
			||||||
            opts._timeout = timeout
 | 
					        opts._timeout = timeout
 | 
				
			||||||
            opts._devices = [create_device(interface=LOOPBACK)]
 | 
					        opts._devices = [create_device(interface=LOOPBACK)]
 | 
				
			||||||
            opts._threads = threads
 | 
					        opts._threads = threads
 | 
				
			||||||
            return opts
 | 
					        return opts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
 | 
					    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
 | 
				
			||||||
            store = c10d.FileStore(self.file_name, self.world_size)
 | 
					        store = c10d.FileStore(self.file_name, self.world_size)
 | 
				
			||||||
            c10d.init_process_group(
 | 
					        c10d.init_process_group(
 | 
				
			||||||
                backend="gloo", rank=self.rank, world_size=self.world_size, store=store
 | 
					            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if with_new_group:
 | 
				
			||||||
 | 
					            pg = c10d.new_group(backend="gloo")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            _pg = c10d.ProcessGroupGloo(
 | 
				
			||||||
 | 
					                store, self.rank, self.world_size, self.opts(timeout=timeout)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            if with_new_group:
 | 
					            pg = c10d._create_process_group_wrapper(
 | 
				
			||||||
                pg = c10d.new_group(backend="gloo")
 | 
					                _pg,
 | 
				
			||||||
            else:
 | 
					                "unused",
 | 
				
			||||||
                _pg = c10d.ProcessGroupGloo(
 | 
					                store,
 | 
				
			||||||
                    store, self.rank, self.world_size, self.opts(timeout=timeout)
 | 
					                self.rank,
 | 
				
			||||||
                )
 | 
					                self.world_size,
 | 
				
			||||||
                pg = c10d._create_process_group_wrapper(
 | 
					                timeout=timeout,
 | 
				
			||||||
                    _pg,
 | 
					            )
 | 
				
			||||||
                    "unused",
 | 
					        return pg
 | 
				
			||||||
                    store,
 | 
					 | 
				
			||||||
                    self.rank,
 | 
					 | 
				
			||||||
                    self.world_size,
 | 
					 | 
				
			||||||
                    timeout=timeout,
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            return pg
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_collective_hang(self):
 | 
					    def test_collective_hang(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(timeout=2.0)
 | 
					        pg = self._create_wrapper_pg(timeout=2.0)
 | 
				
			||||||
            self._test_collective_hang(pg)
 | 
					        self._test_collective_hang(pg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # NOTE: these tests are separated by debug level instead of combined into
 | 
					    # NOTE: these tests are separated by debug level instead of combined into
 | 
				
			||||||
        # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
 | 
					    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
 | 
				
			||||||
        # combined after that is resolved.
 | 
					    # combined after that is resolved.
 | 
				
			||||||
        @with_dist_debug_levels(levels=["DETAIL"])
 | 
					    @with_dist_debug_levels(levels=["DETAIL"])
 | 
				
			||||||
        def test_collectives_op_mismatch_debug_mode(self):
 | 
					    def test_collectives_op_mismatch_debug_mode(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=True)
 | 
					        pg = self._create_wrapper_pg(with_new_group=True)
 | 
				
			||||||
            self._test_collectives_op_mismatch(pg)
 | 
					        self._test_collectives_op_mismatch(pg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @with_dist_debug_levels(levels=["OFF"])
 | 
					    @with_dist_debug_levels(levels=["OFF"])
 | 
				
			||||||
        def test_collectives_op_mismatch(self):
 | 
					    def test_collectives_op_mismatch(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=False)
 | 
					        pg = self._create_wrapper_pg(with_new_group=False)
 | 
				
			||||||
            self._test_collectives_op_mismatch(pg)
 | 
					        self._test_collectives_op_mismatch(pg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @with_dist_debug_levels(levels=["DETAIL"])
 | 
					    @with_dist_debug_levels(levels=["DETAIL"])
 | 
				
			||||||
        def test_collective_shape_mismatch_debug_mode(self):
 | 
					    def test_collective_shape_mismatch_debug_mode(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=True)
 | 
					        pg = self._create_wrapper_pg(with_new_group=True)
 | 
				
			||||||
            self._test_collective_shape_mismatch(pg)
 | 
					        self._test_collective_shape_mismatch(pg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @with_dist_debug_levels(levels=["OFF"])
 | 
					    @with_dist_debug_levels(levels=["OFF"])
 | 
				
			||||||
        def test_collective_shape_mismatch(self):
 | 
					    def test_collective_shape_mismatch(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=False)
 | 
					        pg = self._create_wrapper_pg(with_new_group=False)
 | 
				
			||||||
            self._test_collective_shape_mismatch(pg)
 | 
					        self._test_collective_shape_mismatch(pg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @skip_if_lt_x_gpu(4)
 | 
					    @skip_if_lt_x_gpu(4)
 | 
				
			||||||
        @with_dist_debug_levels(levels=["DETAIL"])
 | 
					    @with_dist_debug_levels(levels=["DETAIL"])
 | 
				
			||||||
        def test_collectives_op_mismatch_cuda_debug_mode(self):
 | 
					    def test_collectives_op_mismatch_cuda_debug_mode(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=True)
 | 
					        pg = self._create_wrapper_pg(with_new_group=True)
 | 
				
			||||||
            self._test_collectives_op_mismatch(pg, use_cuda=True)
 | 
					        self._test_collectives_op_mismatch(pg, use_cuda=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @skip_if_lt_x_gpu(4)
 | 
					    @skip_if_lt_x_gpu(4)
 | 
				
			||||||
        @with_dist_debug_levels(levels=["OFF"])
 | 
					    @with_dist_debug_levels(levels=["OFF"])
 | 
				
			||||||
        def test_collectives_op_mismatch_cuda(self):
 | 
					    def test_collectives_op_mismatch_cuda(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=False)
 | 
					        pg = self._create_wrapper_pg(with_new_group=False)
 | 
				
			||||||
            self._test_collectives_op_mismatch(pg, use_cuda=True)
 | 
					        self._test_collectives_op_mismatch(pg, use_cuda=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @skip_if_lt_x_gpu(4)
 | 
					    @skip_if_lt_x_gpu(4)
 | 
				
			||||||
        @with_dist_debug_levels(levels=["DETAIL"])
 | 
					    @with_dist_debug_levels(levels=["DETAIL"])
 | 
				
			||||||
        def test_collective_shape_mismatch_cuda_debug_mode(self):
 | 
					    def test_collective_shape_mismatch_cuda_debug_mode(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=True)
 | 
					        pg = self._create_wrapper_pg(with_new_group=True)
 | 
				
			||||||
            self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
					        self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @skip_if_lt_x_gpu(4)
 | 
					    @skip_if_lt_x_gpu(4)
 | 
				
			||||||
        @with_dist_debug_levels(levels=["OFF"])
 | 
					    @with_dist_debug_levels(levels=["OFF"])
 | 
				
			||||||
        def test_collective_shape_mismatch_cuda(self):
 | 
					    def test_collective_shape_mismatch_cuda(self):
 | 
				
			||||||
            pg = self._create_wrapper_pg(with_new_group=False)
 | 
					        pg = self._create_wrapper_pg(with_new_group=False)
 | 
				
			||||||
            self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
					        self._test_collective_shape_mismatch(pg, use_cuda=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
@ -65,7 +65,6 @@ TESTS = [
 | 
				
			|||||||
    'test_dataloader',
 | 
					    'test_dataloader',
 | 
				
			||||||
    'test_datapipe',
 | 
					    'test_datapipe',
 | 
				
			||||||
    'distributed/test_data_parallel',
 | 
					    'distributed/test_data_parallel',
 | 
				
			||||||
    'distributed/test_distributed_fork',
 | 
					 | 
				
			||||||
    'distributed/test_distributed_spawn',
 | 
					    'distributed/test_distributed_spawn',
 | 
				
			||||||
    'distributions/test_constraints',
 | 
					    'distributions/test_constraints',
 | 
				
			||||||
    'distributions/test_distributions',
 | 
					    'distributions/test_distributions',
 | 
				
			||||||
@ -212,7 +211,6 @@ WINDOWS_BLOCKLIST = [
 | 
				
			|||||||
    'distributed/rpc/test_faulty_agent',
 | 
					    'distributed/rpc/test_faulty_agent',
 | 
				
			||||||
    'distributed/rpc/test_tensorpipe_agent',
 | 
					    'distributed/rpc/test_tensorpipe_agent',
 | 
				
			||||||
    'distributed/rpc/cuda/test_tensorpipe_agent',
 | 
					    'distributed/rpc/cuda/test_tensorpipe_agent',
 | 
				
			||||||
    'distributed/test_distributed_fork',
 | 
					 | 
				
			||||||
    'distributed/pipeline/sync/skip/test_api',
 | 
					    'distributed/pipeline/sync/skip/test_api',
 | 
				
			||||||
    'distributed/pipeline/sync/skip/test_gpipe',
 | 
					    'distributed/pipeline/sync/skip/test_gpipe',
 | 
				
			||||||
    'distributed/pipeline/sync/skip/test_inspect_skip_layout',
 | 
					    'distributed/pipeline/sync/skip/test_inspect_skip_layout',
 | 
				
			||||||
@ -294,7 +292,6 @@ TARGET_DET_LIST = [
 | 
				
			|||||||
    'test_testing',
 | 
					    'test_testing',
 | 
				
			||||||
    'test_view_ops',
 | 
					    'test_view_ops',
 | 
				
			||||||
    'distributed/nn/jit/test_instantiator',
 | 
					    'distributed/nn/jit/test_instantiator',
 | 
				
			||||||
    'distributed/test_distributed_fork',
 | 
					 | 
				
			||||||
    'distributed/rpc/test_tensorpipe_agent',
 | 
					    'distributed/rpc/test_tensorpipe_agent',
 | 
				
			||||||
    'distributed/rpc/cuda/test_tensorpipe_agent',
 | 
					    'distributed/rpc/cuda/test_tensorpipe_agent',
 | 
				
			||||||
    'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks',
 | 
					    'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks',
 | 
				
			||||||
@ -576,7 +573,7 @@ def test_distributed(test_module, test_directory, options):
 | 
				
			|||||||
            os.environ['INIT_METHOD'] = 'env://'
 | 
					            os.environ['INIT_METHOD'] = 'env://'
 | 
				
			||||||
            os.environ.update(env_vars)
 | 
					            os.environ.update(env_vars)
 | 
				
			||||||
            if with_init_file:
 | 
					            if with_init_file:
 | 
				
			||||||
                if test_module in ["test_distributed_fork", "test_distributed_spawn"]:
 | 
					                if test_module == "test_distributed_spawn":
 | 
				
			||||||
                    init_method = f'{FILE_SCHEMA}{tmp_dir}/'
 | 
					                    init_method = f'{FILE_SCHEMA}{tmp_dir}/'
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file'
 | 
					                    init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file'
 | 
				
			||||||
@ -611,7 +608,6 @@ CUSTOM_HANDLERS = {
 | 
				
			|||||||
    'test_cuda_primary_ctx': test_cuda_primary_ctx,
 | 
					    'test_cuda_primary_ctx': test_cuda_primary_ctx,
 | 
				
			||||||
    'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja,
 | 
					    'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja,
 | 
				
			||||||
    'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja,
 | 
					    'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja,
 | 
				
			||||||
    'distributed/test_distributed_fork': test_distributed,
 | 
					 | 
				
			||||||
    'distributed/test_distributed_spawn': test_distributed,
 | 
					    'distributed/test_distributed_spawn': test_distributed,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,6 @@ class DeterminationTest(unittest.TestCase):
 | 
				
			|||||||
        "test_jit_profiling",
 | 
					        "test_jit_profiling",
 | 
				
			||||||
        "test_jit",
 | 
					        "test_jit",
 | 
				
			||||||
        "test_torch",
 | 
					        "test_torch",
 | 
				
			||||||
        "distributed/test_distributed_fork",
 | 
					 | 
				
			||||||
        "distributed/test_distributed_spawn",
 | 
					        "distributed/test_distributed_spawn",
 | 
				
			||||||
        "test_cpp_extensions_aot_ninja",
 | 
					        "test_cpp_extensions_aot_ninja",
 | 
				
			||||||
        "test_cpp_extensions_aot_no_ninja",
 | 
					        "test_cpp_extensions_aot_no_ninja",
 | 
				
			||||||
@ -104,7 +103,6 @@ class DeterminationTest(unittest.TestCase):
 | 
				
			|||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            self.determined_tests(["torch/utils/cpp_extension.py"]),
 | 
					            self.determined_tests(["torch/utils/cpp_extension.py"]),
 | 
				
			||||||
            [
 | 
					            [
 | 
				
			||||||
                "distributed/test_distributed_fork",
 | 
					 | 
				
			||||||
                "test_cpp_extensions_aot_ninja",
 | 
					                "test_cpp_extensions_aot_ninja",
 | 
				
			||||||
                "test_cpp_extensions_aot_no_ninja",
 | 
					                "test_cpp_extensions_aot_no_ninja",
 | 
				
			||||||
                "test_utils",
 | 
					                "test_utils",
 | 
				
			||||||
 | 
				
			|||||||
@ -630,7 +630,6 @@ class TestFile:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def append(self, test_case: TestCase, test_type: str) -> None:
 | 
					    def append(self, test_case: TestCase, test_type: str) -> None:
 | 
				
			||||||
        is_multi_test = self.name == 'test_cpp_extensions_aot' or \
 | 
					        is_multi_test = self.name == 'test_cpp_extensions_aot' or \
 | 
				
			||||||
            self.name == 'distributed/test_distributed_fork' or \
 | 
					 | 
				
			||||||
            self.name == 'distributed/test_distributed_spawn' or \
 | 
					            self.name == 'distributed/test_distributed_spawn' or \
 | 
				
			||||||
            self.name == 'distributed/test_c10d_gloo' or \
 | 
					            self.name == 'distributed/test_c10d_gloo' or \
 | 
				
			||||||
            self.name == 'cpp'  # The caffe2 cpp tests spawn duplicate test cases as well.
 | 
					            self.name == 'cpp'  # The caffe2 cpp tests spawn duplicate test cases as well.
 | 
				
			||||||
 | 
				
			|||||||
@ -85,7 +85,6 @@ python test/distributed/test_store.py
 | 
				
			|||||||
python test/distributed/test_pg_wrapper.py
 | 
					python test/distributed/test_pg_wrapper.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Run distributed tests, including tests for Distributed Data Parallel.
 | 
					# Run distributed tests, including tests for Distributed Data Parallel.
 | 
				
			||||||
python test/run_test.py --verbose -i distributed/test_distributed_fork
 | 
					 | 
				
			||||||
python test/run_test.py --verbose -i distributed/test_distributed_spawn
 | 
					python test/run_test.py --verbose -i distributed/test_distributed_spawn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Run the RPC test suite for the TensorPipeAgent.
 | 
					# Run the RPC test suite for the TensorPipeAgent.
 | 
				
			||||||
 | 
				
			|||||||
@ -9,6 +9,7 @@ import time
 | 
				
			|||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import types
 | 
					import types
 | 
				
			||||||
import unittest
 | 
					import unittest
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
from contextlib import contextmanager
 | 
					from contextlib import contextmanager
 | 
				
			||||||
from datetime import timedelta
 | 
					from datetime import timedelta
 | 
				
			||||||
from enum import Enum
 | 
					from enum import Enum
 | 
				
			||||||
@ -468,6 +469,10 @@ class MultiProcessTestCase(TestCase):
 | 
				
			|||||||
            self.processes.append(process)
 | 
					            self.processes.append(process)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _fork_processes(self) -> None:
 | 
					    def _fork_processes(self) -> None:
 | 
				
			||||||
 | 
					        warnings.warn(
 | 
				
			||||||
 | 
					            "Fork based multiprocessing is dangerous and should not"
 | 
				
			||||||
 | 
					            " be used, for tests with ASAN consider using opt-asan",
 | 
				
			||||||
 | 
					            DeprecationWarning)
 | 
				
			||||||
        proc = torch.multiprocessing.get_context("fork").Process
 | 
					        proc = torch.multiprocessing.get_context("fork").Process
 | 
				
			||||||
        self._start_processes(proc)
 | 
					        self._start_processes(proc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user