Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)

Compare commits: 51 commits, new-codege ... sqzhang_fl
	| Author | SHA1 | Date | |
|---|---|---|---|
| 972b8060bd | |||
| 3e7683ae18 | |||
| f2e9ec2dc5 | |||
| dde4324d8e | |||
| 94c079104d | |||
| a6afee6d94 | |||
| d092857531 | |||
| 6aad5e444a | |||
| c54ce9313b | |||
| 1fe59f4ef7 | |||
| e693fb2bb1 | |||
| 4fe510baf6 | |||
| 7c507b78c4 | |||
| 0019901601 | |||
| 18be18535b | |||
| 2729367313 | |||
| 33537aae24 | |||
| dcdb1337dd | |||
| 9cf0f2bd59 | |||
| 1d2e877c05 | |||
| f27b979b0c | |||
| f30d6047ad | |||
| 75311510ef | |||
| dbd6094d05 | |||
| 397b9d47e9 | |||
| 36a01a8ab9 | |||
| ee336cf58a | |||
| a9e2e745d7 | |||
| ab4df89eea | |||
| 9d02ebe876 | |||
| b61e01cce9 | |||
| f7ce61ba53 | |||
| e7bae15ab1 | |||
| e71b422908 | |||
| 389940ce60 | |||
| b2237a7c85 | |||
| ef5dfe3f3e | |||
| e303dc3c08 | |||
| 265efad2de | |||
| 60f0455905 | |||
| 4898313791 | |||
| f4da9adf6b | |||
| 8f7f35273e | |||
| 44ec9612ed | |||
| 4d3bea2b29 | |||
| 0bcdddc3c1 | |||
| 28b6220312 | |||
| 210b7b65e2 | |||
| 4da10b5cd3 | |||
| f09763814f | |||
| 80923ed5a6 |||
@@ -1732,7 +1732,7 @@ if(BUILD_TEST)
   foreach(test_src ${Caffe2_CPU_TEST_SRCS})
     get_filename_component(test_name ${test_src} NAME_WE)
     add_executable(${test_name} "${test_src}")
-    target_link_libraries(${test_name} torch_library gtest_main)
+    target_link_libraries(${test_name} torch_library gtest_main stdc++)
     target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
     target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
     target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
@@ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) {
 // Extend the nested class outside the parent class
 class TestDebugInfoWriter : public c10d::DebugInfoWriter {
  public:
-  TestDebugInfoWriter() : DebugInfoWriter(0) {}
+  TestDebugInfoWriter(std::string namePrefix)
+      : DebugInfoWriter(namePrefix, 0) {}
 
   void write(const std::string& ncclTrace) override {
     traces_.assign(ncclTrace.begin(), ncclTrace.end());
@@ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
   // The storer here is very similar to the fallback storer.
   // The only difference is that we are storing traces also in memory for
   // validation.
+  std::string fileNamePrefix = c10d::getCvarString(
+      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
   std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
-      std::make_unique<TestDebugInfoWriter>();
+      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
   std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();
-  pg.registerDebugInfoWriter(std::move(wrterForTestPtr));
+  c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr));
 
   // Normal collective case.
   auto work = pg.allreduce(tensors_);
@@ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
 class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
  protected:
   void SetUp() override {
+    // TODO (kwen2501)
+    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
+                 << "will rewrite them after refactoring Work queues.";
     ProcessGroupNCCLErrorsTest::SetUp();
     std::string timeInterval = std::to_string(heartBeatIntervalInSec);
     ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
@@ -3542,11 +3542,12 @@ class SparseCollective(MultiProcessTestCase):
 class NCCLTraceTestBase(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'
+        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'  # see 'timing_enabled' parametrized tests
         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
         self.tempdir = tempfile.TemporaryDirectory()
         os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
+        os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
         self._spawn_processes()
 
     @classmethod
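The environment variables set in this setUp are the same knobs a user would export to turn the flight recorder on outside the test harness. A minimal sketch, assuming the variable names above and purely illustrative values:

```python
# Sketch only: enable the NCCL flight recorder before the process group is
# created. Values are illustrative; the variable names come from the setUp above.
import os

os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"       # ring-buffer entries; > 0 enables tracing
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"          # dump debug info when the watchdog times out
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/nccl_trace_rank_"  # per-rank dump file prefix
os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = "/tmp/nccl_trace_rank_"  # per-rank trigger-pipe prefix

# ...then initialize torch.distributed with the "nccl" backend as usual.
```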
@@ -3617,28 +3618,49 @@ class NCCLTraceTest(NCCLTraceTestBase):
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
-    def test_short(self):
+    @parametrize("timing_enabled", [True, False])
+    def test_short(self, timing_enabled):
         if self.rank == self.MAIN_PROCESS_RANK:
             return
         pg = self._create_process_group_nccl()
+        if timing_enabled:
+            pg._enable_collectives_timing()
         device = self.local_device
         a = torch.full((3, 4), float(self.rank), device=device)
         for i in range(2):
             f = pg.allreduce(a)
         f.wait()
         torch.cuda.synchronize(device=device)
+
+        # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api
+        time.sleep(1)
+
         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+        ver = t['version']
+        self.assertEqual(ver, "1.1")
+        t = t['entries']
         self.assertEqual(len(t), 2)
         last = t[-1]
         self.assertEqual(last['state'], 'completed')
+        s = last['time_discovered_started_ns']
+        f = last['time_discovered_completed_ns']
+        self.assertIsNotNone(f)
+        if timing_enabled:
+            self.assertIsNotNone(s)
+            self.assertTrue(s <= f)
         self.assertIn('test_c10d_nccl.py', str(last['frames']))
         self.assertEqual(last['input_sizes'], ((3, 4),))
         self.assertEqual(last['output_sizes'], ((3, 4),))
         self.assertEqual(last['seq_id'], 2)
         now = datetime.now()
-        event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000)
+        event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000)
         before_test = now - timedelta(minutes=1)
         self.assertTrue(before_test < event_created_time < now)
+        if timing_enabled:
+            # very loose bounds, measured 0.036 ms on devgpu
+            self.assertTrue(0 < last['duration_ms'] < 100)
+        else:
+            self.assertTrue("duration_ms" not in last)
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
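test_short above doubles as a usage recipe for the in-process dump API. A hedged sketch of the same calls outside the test, using only what the tests themselves exercise:

```python
# Sketch: pull the flight-recorder contents from a running rank and summarize.
import pickle
import torch

def summarize_nccl_trace():
    t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
    print("trace version:", t["version"])  # "1.1" in this change
    for e in t["entries"]:
        # duration_ms is only present when collective timing is enabled
        print(e["seq_id"], e["profiling_name"], e["state"],
              e["input_sizes"], "->", e["output_sizes"],
              "duration_ms:", e.get("duration_ms"))
```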
@@ -3698,9 +3720,11 @@ class NCCLTraceTest(NCCLTraceTestBase):
         f.wait()
         torch.cuda.synchronize(device=device)
         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+        t = t['entries']
         self.assertEqual(len(t), 10)
         first = t[0]
         last = t[-1]
+        self.assertEqual(last['profiling_name'], 'nccl:all_reduce')
         self.assertEqual(last['state'], 'completed')
         self.assertIn('test_c10d_nccl.py', str(last['frames']))
         self.assertEqual(last['input_sizes'], ((3, 4),))
@@ -3732,6 +3756,8 @@ class NCCLTraceTest(NCCLTraceTestBase):
                 pg.allreduce(a).wait()
             e.synchronize()
             t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+            t = t['entries']
+            self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
             if self.rank == 0:
                 self.assertEqual(t[-1]['seq_id'], 1)
                 self.assertEqual(t[-1]['state'], 'completed')
@@ -3773,12 +3799,15 @@ class NCCLTraceTest(NCCLTraceTestBase):
                 # give the other thread some time to fill the cuda buffer
                 time.sleep(5)
                 t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+                t = t['entries']
+                self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
                 if self.rank == 0:
                     self.assertEqual(t[-1]['seq_id'], 1)
                     self.assertEqual(t[-1]['state'], 'completed')
                 else:
                     self.assertEqual(t[-1]['seq_id'], 2)
                     self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled))
+                    self.assertIsNone(t[-1]['time_discovered_completed_ns'])
                 # this will eventually cause the missing rank 0
                 # to continue which will unblock the non-zero ranks
                 self.parent.send('next')
@@ -3832,6 +3861,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    @parametrize("timing_enabled", [True, False])
     def test_timeout_dumps(self, timing_enabled):
+        # dump on heartbeatmonitor thread
+        os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
         # need rank0 to crash before looking for its output file
         os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'
 
         if self.rank == self.MAIN_PROCESS_RANK:
             # wait for rank0 to crash before looking for its output file
@@ -3839,6 +3872,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
             self.assertEqual(self._wait_process(0, timeout=90), -6)
             with open(self._trace_name(rank=0), 'rb') as f:
                 t = pickle.load(f)
+                t = t['entries']
                 self.assertEqual(len(t), 2)
                 self.assertEqual(t[0]['seq_id'], 1)
                 self.assertEqual(t[0]['state'], 'completed')
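The timeout test reads the dump back from disk; the same pattern works for post-mortem inspection. A sketch, assuming the default file prefix "/tmp/nccl_trace_rank_" used elsewhere in this change:

```python
# Sketch: load a per-rank flight-recorder dump written on timeout.
import pickle

def load_rank_dump(rank, prefix="/tmp/nccl_trace_rank_"):
    with open(f"{prefix}{rank}", "rb") as f:
        dump = pickle.load(f)
    # dump["entries"] is the ring buffer; dump may also carry per-communicator
    # state under "nccl_comm_state" (see the dump() keys later in this diff).
    return dump["entries"]

# entries = load_rank_dump(0)
```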
@@ -3868,7 +3902,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
 instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
 instantiate_parametrized_tests(NCCLTraceTest)
 
-class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
+class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
     def _check_return_codes(self, elapsed_time):
         # the base test infra assumes processes exit with matching return codes,
         # but we want rank0 to abort and rank1 to exit cleanly in this test
@@ -3877,7 +3911,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
-    def test_timeout_dumps_on_idle_ranks(self):
+    def test_timeout_dumps_on_stuck_ranks(self):
 
         if self.rank == self.MAIN_PROCESS_RANK:
             # wait for both rank0 and 1 to crash before looking for both ranks' output
@@ -3888,20 +3922,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
             self.assertTrue(os.path.exists(self._trace_name(rank=0)))
             with open(self._trace_name(rank=0), 'rb') as f:
                 t = pickle.load(f)
+                t = t['entries']
                 self.assertEqual(len(t), 2)
             with open(self._trace_name(rank=1), 'rb') as f:
                 t = pickle.load(f)
+                t = t['entries']
                 self.assertEqual(len(t), 1)
                 self.assertEqual(t[0]['seq_id'], 1)
                 self.assertEqual(t[0]['state'], 'completed')
             return
 
-        # Set heartbeat timeout to a shorter one (default timeout is 2 min).
-        os.environ[
-            "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"
-        ] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}"
         pg = self._create_process_group_nccl()
-
         device = self.local_device
         with torch.cuda.device(device):
             a = torch.full((3, 4), float(self.rank), device=device)
@@ -3910,12 +3941,68 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
             if self.rank == 0:
                 pg.allreduce(a).wait()
 
-            # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
+            # rank 0 will get stuck, timeout and then signal a timeout to all ranks.
             torch.cuda.synchronize()
 
-            # Force rank 1 to idle so that it also gets debug info dump triggered.
             if self.rank == 1:
-                time.sleep(6)
+                # Force rank 1 to idle so that it will eventually timeout as well after
+                # getting the global signal to dump the debugging info.
+                time.sleep(600)
+
+class NcclErrorDumpTest(NCCLTraceTestBase):
+    def _wait_process(self, rank, timeout):
+        try:
+            self.processes[rank].join(timeout)
+            return self.processes[rank].exitcode
+        except TimeoutError:
+            return None
+
+    def _check_return_codes(self, elapsed_time):
+        # the base test infra assumes processes exit with matching return codes,
+        # but we want rank0 to abort with exception and rank1 to exit with exit 1
+        self.assertEqual(self.processes[0].exitcode, -6)
+        self.assertEqual(self.processes[1].exitcode, 1)
+
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_nccl_errors_dump(self):
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
+        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
+        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
+        # need rank0 to dump before abort
+        os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5'
+
+        if self.rank == self.MAIN_PROCESS_RANK:
+            # wait for both rank0 and 1 to crash before looking for dump
+            self.assertEqual(self._wait_process(0, timeout=90), -6)
+            self.assertEqual(self._wait_process(1, timeout=90), 1)
+            # verify that the trace file exists for rank0
+            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
+            return
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(
+            store,
+            self.rank,
+            self.world_size,
+            timeout=timedelta(seconds=10),
+        )
+        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        if self.rank == 0:
+            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            # expect an error to be raised
+            with self.assertRaisesRegex(dist.DistBackendError, ""):
+                # Block the current stream on the NCCL stream
+                work.wait()
+                # Run some GPU operations
+                a = torch.rand(10).cuda(self.rank)
+        elif self.rank == 1:
+            # Clean up structures (ex: files for FileStore before going down)
+            del process_group
+            sys.exit(1)
+
+
 if __name__ == "__main__":
     assert (
@@ -71,7 +71,7 @@ class StoreTestBase:
     def _create_store(self, i):
         raise RuntimeError("not implemented")
 
-    def _test_set_get(self, fs):
+    def _test_set_get_check(self, fs):
         fs.add("key", 1)
         fs.add("key", 2)
         fs.add("key", 3)
@@ -90,14 +90,16 @@ class StoreTestBase:
         self.assertEqual(b"value1", fs.get("key1"))
         self.assertEqual(b"value2", fs.get("key2"))
         self.assertEqual(b"21", fs.get("key3"))
+        self.assertTrue(fs.check(["key3"]))
+        self.assertFalse(fs.check(["Randomkey3"]))
 
         fs.set("-key3", "7")
         self.assertEqual(b"7", fs.get("-key3"))
         fs.delete_key("-key3")
         self.assertEqual(fs.num_keys(), self.num_keys_total)
 
-    def test_set_get(self):
-        self._test_set_get(self._create_store())
+    def test_set_get_check(self):
+        self._test_set_get_check(self._create_store())
 
     def _test_compare_set(self, store):
         missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
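For context on the check() assertions added above, here is a hedged sketch of the Store calls involved, using a standalone TCPStore; any of the stores exercised by these tests is expected to behave the same way:

```python
# Sketch: the set/get/check pattern the updated test exercises.
from datetime import timedelta
import torch.distributed as dist

store = dist.TCPStore("127.0.0.1", 29500, world_size=1, is_master=True,
                      timeout=timedelta(seconds=30))
store.set("key1", "value1")
print(store.get("key1"))       # b'value1'
print(store.check(["key1"]))   # True: key exists
print(store.check(["nokey"]))  # False: check() returns instead of blocking like get()
```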
@@ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
     def num_keys_total(self):
         return 6
 
+    def test_underlying_non_prefix_store(self):
+        store = self._create_store()
+        wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
+        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
+        self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)
+
 class MyPythonStore(dist.Store):
     def __init__(self):
         super().__init__()
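The new test above is the whole contract: however deeply a store is wrapped in PrefixStore layers, _underlying_non_prefix_store unwraps back to the base store. A small sketch with hypothetical prefixes:

```python
# Sketch: unwrap nested PrefixStores back to the base TCPStore.
from datetime import timedelta
import torch.distributed as dist

base = dist.TCPStore("127.0.0.1", 29501, world_size=1, is_master=True,
                     timeout=timedelta(seconds=30))
wrapped = dist.PrefixStore("outer", dist.PrefixStore("inner", base))

assert wrapped._underlying_non_prefix_store == base  # same underlying store
```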
@@ -8,6 +8,7 @@
 #include <memory>
 #include <mutex>
 
 #include <ATen/ATen.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
 #include <nccl.h>
@@ -175,14 +176,29 @@ std::string getNcclErrorDetailStr(
     c10::optional<std::string> processGroupFailureReason = c10::nullopt);
 
+// Write NCCL debug info to local disk or any storage users define.
+// There are some constrains we set for the debug info writer:
+// 1. The writer should only be registered once.
+// 2. Once registered, users cannot change it including un-register.
+// 3. It is recommended to register the customized writer in the trainer setup,
+//    If users don't register before calling launchAsyncDebugDump, then users
+//    lose the chance to register (and the default writer will be
+//    auto-registered).
 class TORCH_API DebugInfoWriter {
  public:
-  DebugInfoWriter(int rank);
   virtual ~DebugInfoWriter();
   virtual void write(const std::string& ncclTrace);
   static DebugInfoWriter& getWriter(int rank);
+  static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);
 
+ protected:
+  DebugInfoWriter(std::string namePrefix, int rank) {
+    filename_ = c10::str(namePrefix, rank);
+  }
   std::string filename_;
 
+ private:
+  static std::unique_ptr<DebugInfoWriter> writer_;
+  static std::atomic<bool> hasWriterRegistered_;
 };
 
 // RAII wrapper for NCCL communicator
@@ -267,6 +283,18 @@ class NCCLComm {
   }
 #endif
 
+#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
+  std::unordered_map<std::string, std::string> ncclCommDump() {
+    std::unordered_map<std::string, std::string> dump;
+    if (isAborted()) {
+      LOG(INFO) << "Communicator was aborted before trying to dump its state.";
+      return dump;
+    }
+    C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt);
+    return dump;
+  }
+#endif
+
   ncclUniqueId getNcclId() {
     return ncclId_;
   }
@@ -322,6 +350,9 @@ class NCCLComm {
     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
     // timeout)
     commFailureReason_ = commFailureReason;
+    LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
+              << (commFailureReason ? *commFailureReason
+                                    : "No abort reason provided.");
 #ifndef NCCL_HAS_COMM_NONBLOCKING
     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
 #else
@@ -421,6 +452,8 @@ class NCCLComm {
 #endif
   }
 
+  friend class ProcessGroupNCCL;
+
  protected:
   ncclComm_t ncclComm_;
   // Unique nccl_id for this communicator.
@@ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
   return store_;
 }
 
+c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
+  c10::intrusive_ptr<Store> store = store_;
+
+  while (store) {
+    // Attempt to dynamically cast to PrefixStore
+    PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
+    if (asPrefixStore) {
+      store = asPrefixStore->getUnderlyingStore();
+    } else {
+      break; // We've reached a non-PrefixStore
+    }
+  }
+
+  TORCH_CHECK(
+      store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
+  return store;
+}
+
 } // namespace c10d
@@ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store {
 
   c10::intrusive_ptr<Store> getUnderlyingStore();
 
+  // Recursively to fetch the store before layers of wrapping with PrefixStore.
+  c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();
+
  protected:
   std::string prefix_;
   c10::intrusive_ptr<Store> store_;

(File diff suppressed because it is too large.)
@@ -1,7 +1,15 @@
 #pragma once
+
+#if defined(__linux__)
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#endif
 
 #ifdef USE_C10D_NCCL
 
 #include <atomic>
 #include <chrono>
 #include <future>
 #include <iostream>
@@ -12,6 +20,7 @@
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
 
 #include <ATen/DynamicLibrary.h>
@@ -26,52 +35,74 @@
 #include <torch/custom_class.h>
 
 namespace c10d {
-// Environment variable which controls whether we perform a NCCL healt check
+// Control whether we perform a NCCL health check or not
 // which ensures communicators are healthy at the beginning of init.
 static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = {
     "TORCH_ENABLE_NCCL_HEALTH_CHECK",
     "ENABLE_NCCL_HEALTH_CHECK"};
 
-// Environment variable which controls whether or not wait() is blocking or
-// non-blocking.
+// Control whether or not wait() is blocking or non-blocking.
 static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
     "TORCH_NCCL_BLOCKING_WAIT",
     "NCCL_BLOCKING_WAIT"};
 
-// Environment variable which controls whether or not we perform Async Error
-// Handling with NCCL.
+// Control whether or not we perform Async Error Handling with NCCL.
 static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
     "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    "NCCL_ASYNC_ERROR_HANDLING"};
 
-// Environment Variable to control whether dumping debug info on watchdog
+// Control whether dumping debug info on watchdog
 // timeout is enabled. This variable must be set together with
 // TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
 static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
     "TORCH_NCCL_DUMP_ON_TIMEOUT"};
 
-// Environment Variable to control whether Desync Debug is enabled.
-// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
+// Control whether Desync Debug is enabled. This variable must be set
+// together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
 static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
     "TORCH_NCCL_DESYNC_DEBUG",
     "NCCL_DESYNC_DEBUG"};
 
 // Enable recording start-events for all ProcessGroupNCCL collectives, and
 // compute accurate collective timing per-collective. (Note: end-events are
 // recorded by default. Turn on this flag can increase chances of a watchdog
 // hang due to performing a CUDA event query which eventually calls
 // cudaEventElapsedTime() API.
 static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
     "TORCH_NCCL_ENABLE_TIMING",
     "NCCL_ENABLE_TIMING"};
 
+// Enable monitoring thread which aborts the process when the ProcessGroupNCCL
+// Watchdog thread gets stuck and no heartbeat is detected after
+// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
+// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged
+// time than necessary tying up cluster resources.
+static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
+    "TORCH_NCCL_ENABLE_MONITORING"};
+
 // Control the watchdog heartbeat timeout period after which the monitoring
 // thread will abort the process.
 static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
     "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};
 
 // The maximum number of events we store in the flight recorder's ring buffer.
 // (One event could be the start or end of a collective, for example).
 static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
     "TORCH_NCCL_TRACE_BUFFER_SIZE"};
 
+// Control how much extra time we will wait for dumping the debugging info
+// before we exit and throws timeout exception.
+static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
+    "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};
+
+// Control the interval inside the watchdog thread to check the coordinated
+// signal from other ranks, e.g. to dump the debugging information.
+static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
+    "TORCH_NCCL_COORD_CHECK_MILSEC"};
+
 constexpr const char* NCCL_BACKEND_NAME = "nccl";
 
-constexpr const char* TIMEOUT_DUMP = "timeout_dump";
+constexpr const char* EXCEPTION_DUMP = "exception_dump";
 
 constexpr auto kProcessGroupNCCLDefaultTimeout =
     std::chrono::milliseconds(10 * 60 * 1000);
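Read together, the comments above imply one coherent configuration for watchdog monitoring plus debug dumps. A hedged sketch with purely illustrative values (none of these numbers are mandated by the change):

```python
# Sketch: wire up the monitoring/dump knobs documented above before launch.
import os

os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "1"             # monitor thread aborts if the watchdog hangs
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "600"       # how long to wait for a watchdog heartbeat
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"          # flight-recorder ring-buffer size (> 0)
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"               # requires the two settings above
os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "30000"  # extra wait for the dump before exiting
os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"         # poll interval for cross-rank dump signals
```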
@@ -110,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
     {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
      "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
 
+#if defined(__linux__)
+struct DumpPipe {
+  DumpPipe(int rank) {
+    std::string fileStem =
+        getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
+    if (fileStem.empty() ||
+        getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
+      return;
+    }
+    TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
+    std::string filename = c10::str(fileStem, rank, ".pipe");
+    TORCH_CHECK(
+        unlink(filename.c_str()) != -1 || errno == ENOENT,
+        "Error removing existing named pipe ",
+        filename);
+    TORCH_CHECK(
+        mkfifo(filename.c_str(), 0666) != -1,
+        "Error creating named pipe ",
+        filename);
+    fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
+    LOG(INFO) << "Pipe file " << filename
+              << " has been opened, write to it to trigger NCCL Debug Dump.";
+    TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
+  }
+  bool shouldDump() {
+    if (fd_ == -1) {
+      return false;
+    }
+    char buf[128];
+    // non-blocking from O_NONBLOCK above.
+    // Ignore EINTR because we already will poll this
+    // again later.
+    ssize_t bytesRead = read(fd_, &buf, 128);
+    return bytesRead > 0;
+  }
+  ~DumpPipe() {
+    if (fd_ != -1) {
+      close(fd_);
+    }
+  }
+
+ private:
+  int fd_ = -1;
+};
+#else
+struct DumpPipe {
+  DumpPipe(int rank) {}
+  bool shouldDump() {
+    return false;
+  }
+};
+#endif
+
 // ProcessGroupNCCL implements NCCL bindings for c10d.
 //
 // All functions of the class are expected to be called in the same order
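Per the DumpPipe constructor above, each rank opens a FIFO named `<TORCH_NCCL_DEBUG_INFO_PIPE_FILE><rank>.pipe`, and any bytes written to it make shouldDump() return true. A hedged sketch of an external trigger, assuming the prefix "/tmp/nccl_trace_rank_":

```python
# Sketch: trigger an on-demand flight-recorder dump for one rank by writing
# to its named pipe. Assumes TORCH_NCCL_DEBUG_INFO_PIPE_FILE=/tmp/nccl_trace_rank_.
import os

def trigger_dump(rank, pipe_prefix="/tmp/nccl_trace_rank_"):
    path = f"{pipe_prefix}{rank}.pipe"
    # O_NONBLOCK so the write does not hang if the FIFO is not being read.
    fd = os.open(path, os.O_WRONLY | os.O_NONBLOCK)
    try:
        os.write(fd, b"1")  # any bytes cause shouldDump() to return true
    finally:
        os.close(fd)
```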
@@ -205,6 +289,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
     uint64_t getSequencenumber() const override;
 
+    const std::string& logPrefix() const;
+
     // Helper function that sets an exception_ptr on the WorkNCCL object.
     void setException(std::exception_ptr exception_ptr);
 
@@ -540,8 +626,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   void enableCollectivesTiming() override;
 
-  // Provide an API for users to define their own ways to store NCCL debug info.
-  void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
+  // Helper function for iteratively aborting communicators in the provided map
+  void abortCommsFromMap(
+      std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
+          ncclCommsMap,
+      c10::optional<std::string> abortReason);
 
   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
   // instead of relying on ProcessGroupNCCL destructor.
@@ -694,6 +783,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Desync debug helper
   void logWorkEnd(WorkNCCL& work);
 
+  // Generates a prefix that is unique to this process group and rank, for
+  // disambiguating logs
+  std::string createLogPrefix() const;
+
+  // Returns the unique prefix created in createLogPrefix
+  const std::string& logPrefix() const;
+
+  // Returns the global rank of the device. This function assumes that users
+  // always create a default global process group(PG) which includes all
+  // devices. It is called in the constructor of ProcessGroupNCCL, so it always
+  // return the rank_ of the the very first PG created, aka, default global PG.
+  const int& globalRank() const;
+
  protected:
   // Function that runs as part of a separate thread aside from watchdog
   // thread because we need to check the heartbeat from watchdog thread
@@ -712,6 +814,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // for dump completion.
   std::future<bool> launchAsyncDebugDump();
 
+  // Helper to wait up to the specified timeout and then abandon the dump.
+  // Logs on timeout, and asserts the future's status is as expected.
+  void waitForDumpOrTimeout(
+      std::future<bool>& fut,
+      const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
+      size_t timeout_sec = 30);
+
+  // A helper function to wait for a future to complete or timeout.
+  void waitForFutureOrTimeout(
+      std::future<bool>& fut,
+      const std::chrono::milliseconds& timeOutMilSec,
+      const std::string& futDescription);
+
   // When watchdog timeout, this function will be called and return debug info
   // for users. For now we only get information from retrieveDesyncReport.
   // We are working on enabling more useful debug information for watchdog
@@ -720,9 +835,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   static const int64_t kWatchdogThreadSleepMillis;
 
-  // The store is used to broadcast the NCCL unique ID of rank 0.
+  // The store is used to broadcast the NCCL unique ID of rank 0. This store
+  // comes with prefix and it is different across ProcessGroup NCCL instances
+  // (aka, different ProcessGroups).
   c10::intrusive_ptr<Store> store_;
 
+  // Reference to the store without prefix so that keys are same across all
+  // ProcessGroup NCCL instances and (key, value) pairs written to the store are
+  // global.
+  c10::intrusive_ptr<Store> globalStore_;
+
   bool storeError_{false};
 
   const c10::intrusive_ptr<Options> options_;
@@ -781,11 +903,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::mutex mutex_;
 
   // Heartbeat of watchdog thread.
-  uint64_t heartbeat_;
+  std::atomic_uint64_t heartbeat_;
 
   // The time interval used for deciding whether there is no watchdog heartbeat.
   int heartbeatTimeoutInSec_;
 
+  // Extra time of sleep when waiting for timeout dump to finish.
+  int waitTimeoutDumpInMilSec_;
+
+  // Interval of check coordinated signals in ProcessGroupNCCL from other ranks
+  // e.g., trigger the dump of the debugging info for timeout when notified.
+  int coordCheckIntervalMilSec_;
+
   // Size of ring buffer where we store NCCL Traces for debugging.
   int ncclTraceBufferSize_;
 
@@ -815,6 +944,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Whether there are hooks pending to be fired
   std::atomic<bool> hasPendingHooks_;
 
+  // This is the signal from watchdog threads to indicate whether the monitor
+  // thread should dump. Making it static so that it is accessiable from all the
+  // PGs. With this flag, monitor thread would dump debug info under any one of
+  // the 3 conditions: 1: this flag is set to true by the watchdog thread when
+  // it detects a timeout. 2: timeout signal is received from
+  // other ranks through tcpstore 3: no heartbeat of watchdog Note that only the
+  // monitor thread from PG0 should dump the debug info and only once
+  static std::atomic<bool> shouldDump_;
+
   // Mutex to Guard workMetaList_
   std::mutex workMetaListMutex_;
 
@@ -823,9 +961,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
-  bool writeDebugInfo_ = false;
 
-  // Mutex to Guard the check of writeDebugInfo_
-  std::mutex writeDebugInfoMutex_;
 
   // Condition Variable for watchdog thread sleep
   std::condition_variable workMetaListCV_;
 
@@ -902,8 +1037,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Whether or not to enable timeout root cause analysis.
   bool desyncDebug_;
 
-  // Whether or not to dump debug info on timeout
-  bool dumpOnTimeout_;
+  // Whether or not to dump debug info on exception including both watchdog
+  // timeout and nccl errors.
+  bool dumpOnException_;
 
   // Whether or not to create start CUDAEvent and enable timing for start
   // and end events. Note that enableTiming_ is always true if desyncDebug_
@@ -929,14 +1065,25 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   std::exception_ptr watchDogException_ = nullptr;
 
-  // The callback function to store NCCL debug info.
-  std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
-
   size_t uid_;
+
+  std::string logPrefix_;
 };
 
 TORCH_API std::string dump_nccl_trace();
 
+// Gets a mutable reference to a global optional function.  Heartbeat Monitor
+// will query this function and if available, call it to dump traces. Inside
+// fbcode, we store a function here that uses an internal tool for process
+// tracing
+TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper();
+
+// Similar to get_cpp_trace_dumper, this stores a function defined in
+// torch-python layer that lets us check whether the GIL can be acquired,
+// helpful for instrumenting in cases where a hang was observed.
+typedef bool (*gil_checker_t)();
+
+TORCH_API gil_checker_t& get_gil_checker();
 } // namespace c10d
 
 #endif // USE_C10D_NCCL
@@ -13,7 +13,6 @@
 #include <string>
 #include <system_error>
 #include <vector>
 
 namespace c10d {
 
 /* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */
@@ -269,10 +268,20 @@ inline std::string retrieveDesyncReport(
 
 #ifdef USE_C10D_NCCL
 
-DebugInfoWriter::DebugInfoWriter(int rank) {
-  std::string fileName = getCvarString(
-      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
-  filename_ = c10::str(fileName, rank);
+/* Helper used by work::getDuration() and nccl flight recorder */
+float getDurationFromFirstEvent(
+    const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
+    const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
+  TORCH_CHECK(
+      ncclStartEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple start events.");
+  TORCH_CHECK(
+      ncclEndEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple end events.");
+  TORCH_CHECK(
+      ncclEndEvents[0].query(),
+      "getDuration can only be called after work is succeeded.")
+  return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
 }
 
 DebugInfoWriter::~DebugInfoWriter() = default;
@@ -293,6 +302,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) {
   LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
 }
 
+DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
+  if (writer_ == nullptr) {
+    std::string fileNamePrefix = getCvarString(
+        {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
+    // Using std::unique_ptr here to auto-delete the writer object
+    // when the pointer itself is destroyed.
+    std::unique_ptr<DebugInfoWriter> writerPtr(
+        new DebugInfoWriter(fileNamePrefix, rank));
+    DebugInfoWriter::registerWriter(std::move(writerPtr));
+  }
+  return *writer_;
+}
+
+void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
+  TORCH_CHECK_WITH(
+      DistBackendError,
+      hasWriterRegistered_.load() == false,
+      "debugInfoWriter already registered");
+  hasWriterRegistered_.store(true);
+  writer_ = std::move(writer);
+}
+
+std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
+std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
+
 inline std::string pickle_str(const c10::IValue& v) {
   std::vector<char> result;
   {
@@ -337,10 +371,10 @@ struct NCCLTraceBuffer {
     // update state information
     size_t pg_id_;
     size_t seq_id_; // as tracked by the process group
-    const char* profiling_name_;
+    std::string profiling_name_;
 
     std::shared_ptr<torch::CapturedTraceback> traceback_;
-    // we borrow pointser to start_ and end_ so we can query the state
+    // we borrow pointers to start_ and end_ so we can query the state
     // on reporting. However, once the event is completed, the call
     // to `complete` will clear these.
     EventList *start_, *end_;
@@ -348,8 +382,18 @@ struct NCCLTraceBuffer {
     // timestamp when the entry was created, likely close to the time the work
    // was 'enqueued'- not necessarily started
     c10::time_t time_created_;
+    c10::optional<float> duration_;
 
-    const char* state_ = "scheduled";
+    // timestamp when our CPU threads discovered that the kernel started.
+    // will always be _after_ it actually started, and can be very late
+    // if the watchdog thread got stuck on CUDA APIs.
+    c10::optional<c10::time_t> time_discovered_started_;
+
+    // timestamp when our CPU threads discovered that the kernel completed.
+    // will always be _after_ it actually complated, and can be the same time
+    // as the discovery of the start if the watchdog thread is stuck on CUDA
+    // APIs
+    c10::optional<c10::time_t> time_discovered_completed_;
 
     // size information for input/output tensors
     c10::SmallVector<int, 4> input_dims_;
@@ -370,7 +414,7 @@ struct NCCLTraceBuffer {
   c10::optional<size_t> record(
       size_t pg_id,
      size_t seq_id,
-      const char* profiling_name,
+      std::string profiling_name,
       const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& outputs,
       EventList* start,
@@ -386,7 +430,7 @@ struct NCCLTraceBuffer {
         id_,
         pg_id,
         seq_id,
-        profiling_name,
+        std::move(profiling_name),
         std::move(traceback),
         std::move(start),
         std::move(end),
@@ -424,8 +468,8 @@ struct NCCLTraceBuffer {
           break;
         }
       }
-      if (started) {
-        r.state_ = "started";
+      if (started && !r.time_discovered_started_) {
+        r.time_discovered_started_ = c10::getTime();
       }
     }
     if (r.end_ != nullptr) {
@@ -436,8 +480,8 @@ struct NCCLTraceBuffer {
           break;
         }
       }
-      if (completed) {
-        r.state_ = "completed";
+      if (completed && !r.time_discovered_completed_) {
+        r.time_discovered_completed_ = c10::getTime();
       }
     }
   }
@ -456,35 +500,96 @@ struct NCCLTraceBuffer {
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void retire_id(c10::optional<size_t> id) {
 | 
			
		||||
  /*
 | 
			
		||||
  Mark an Event as completed and free its events.
 | 
			
		||||
 | 
			
		||||
  This is called by the watchdog thread, and is asynchronous from the
 | 
			
		||||
  perspective of the main thread.
 | 
			
		||||
 | 
			
		||||
  compute_duration defaults to true since retire_id is only called in the
 | 
			
		||||
  watchdog thread, which is currently a place we call cuda APIs which may hang,
 | 
			
		||||
  but care should be taken to avoid computing duration in any function that must
 | 
			
		||||
  never hang. (timing must also be enabled for compute_duration - see
 | 
			
		||||
  TORCH_NCCL_ENABLE_TIMING).
 | 
			
		||||
  */
 | 
			
		||||
  void retire_id(c10::optional<size_t> id, bool compute_duration = true) {
 | 
			
		||||
    if (!enabled_ || !id) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    std::lock_guard<std::mutex> guard(mutex_);
 | 
			
		||||
    auto& entry = entries_.at(*id % max_entries_);
 | 
			
		||||
    if (entry.id_ == *id) {
 | 
			
		||||
      update_state(entry);
 | 
			
		||||
      entry.retired_ = true;
 | 
			
		||||
      entry.start_ = entry.end_ = nullptr;
 | 
			
		||||
 | 
			
		||||
    bool can_compute_duration = false;
 | 
			
		||||
    EventList* startEvents = nullptr;
 | 
			
		||||
    EventList* endEvents = nullptr;
 | 
			
		||||
    c10::optional<float> duration = c10::nullopt;
 | 
			
		||||
 | 
			
		||||
    std::unique_lock<std::mutex> guard(mutex_);
 | 
			
		||||
 | 
			
		||||
    Entry* entry = &entries_.at(*id % max_entries_);
 | 
			
		||||
    if (entry->id_ == *id) {
 | 
			
		||||
      update_state(*entry);
 | 
			
		||||
 | 
			
		||||
      if (compute_duration) {
 | 
			
		||||
        can_compute_duration = entry->time_discovered_completed_.has_value() &&
 | 
			
		||||
            entry->start_ && entry->end_;
 | 
			
		||||
        startEvents = entry->start_;
 | 
			
		||||
        endEvents = entry->end_;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (can_compute_duration) {
 | 
			
		||||
      // Compute duration without without holding the lock, because
 | 
			
		||||
      // cudaEventDuration() can hang, and we need to acquire the lock before we
 | 
			
		||||
      // can dump(), which we never want to block.
 | 
			
		||||
      guard.unlock();
 | 
			
		||||
      duration = getDurationFromFirstEvent(*startEvents, *endEvents);
 | 
			
		||||
      guard.lock();
 | 
			
		||||
 | 
			
		||||
      // Refresh the entry pointer, see if the entry has been overwritten
 | 
			
		||||
      entry = &entries_.at(*id % max_entries_);
 | 
			
		||||
      if (entry->id_ != *id) {
 | 
			
		||||
        LOG(INFO)
 | 
			
		||||
            << "retire_id abandoned for id " << *id
 | 
			
		||||
            << ", event was overwritten while waiting to compute duration.";
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      if (duration.has_value()) {
 | 
			
		||||
        entry->duration_ = duration.value();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    entry->retired_ = true;
 | 
			
		||||
    entry->start_ = entry->end_ = nullptr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string dump() {
  std::string dump(
      const c10::optional<std::unordered_map<
          std::string,
          std::unordered_map<std::string, std::string>>>& ncclDumpMap) {
    auto result = dump_entries();
    auto entries = new_list();
    c10::IValue pg_id_s = "pg_id";
    c10::IValue seq_id_s = "seq_id";
    c10::IValue profiling_name_s = "profiling_name";
    c10::IValue input_sizes_s = "input_sizes";
    c10::IValue output_sizes_s = "output_sizes";
    c10::IValue time_created_s = "time_created_us";
    c10::IValue entries_key = "entries";
    c10::IValue nccl_comm_key = "nccl_comm_state";
    c10::IValue version_key = "version";
    // Update whenever changing contents or formatting of the dump
    // (minor when adding fields, major when changing existing fields)
    c10::IValue version_val = "1.1";

    c10::IValue frames_s = "frames";
    c10::IValue state_s = "state";
    c10::IValue line_s = "line";
    c10::IValue name_s = "name";
    c10::IValue filename_s = "filename";
    c10::IValue retired_s = "retired";
    c10::IValue pg_id_key = "pg_id";
    c10::IValue seq_id_key = "seq_id";
    c10::IValue profiling_name_key = "profiling_name";
    c10::IValue input_sizes_key = "input_sizes";
    c10::IValue output_sizes_key = "output_sizes";
    c10::IValue time_created_key = "time_created_ns";
    c10::IValue duration_key = "duration_ms";

    c10::IValue frames_key = "frames";
    c10::IValue state_key = "state";
    c10::IValue line_key = "line";
    c10::IValue name_key = "name";
    c10::IValue filename_key = "filename";
    c10::IValue retired_key = "retired";
    c10::IValue time_discovered_started_key = "time_discovered_started_ns";
    c10::IValue time_discovered_completed_key = "time_discovered_completed_ns";

    std::vector<torch::CapturedTraceback*> tracebacks;
    for (auto& e : result) {
@ -494,9 +599,9 @@ struct NCCLTraceBuffer {
    std::vector<c10::IValue> all_frames;
    for (const auto& f : stracebacks.all_frames) {
      auto d = new_dict();
      d.insert(name_s, f.funcname);
      d.insert(filename_s, f.filename);
      d.insert(line_s, int64_t(f.lineno));
      d.insert(name_key, f.funcname);
      d.insert(filename_key, f.filename);
      d.insert(line_key, int64_t(f.lineno));
      all_frames.emplace_back(std::move(d));
    }

@ -504,10 +609,13 @@ struct NCCLTraceBuffer {
      auto& e = result.at(i);
      auto& tb = stracebacks.tracebacks.at(i);
      auto dict = new_dict();
      dict.insert(pg_id_s, int64_t(e.pg_id_));
      dict.insert(seq_id_s, int64_t(e.seq_id_));
      dict.insert(profiling_name_s, e.profiling_name_);
      dict.insert(time_created_s, int64_t(e.time_created_ / 1000));
      dict.insert(pg_id_key, int64_t(e.pg_id_));
      dict.insert(seq_id_key, int64_t(e.seq_id_));
      dict.insert(profiling_name_key, e.profiling_name_);
      dict.insert(time_created_key, int64_t(e.time_created_));
      if (e.duration_) {
        dict.insert(duration_key, *e.duration_);
      }

      auto it = e.sizes_.begin();
      auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
@ -523,19 +631,56 @@ struct NCCLTraceBuffer {
        return sizes;
      };

      dict.insert(input_sizes_s, read_sizes(e.input_dims_));
      dict.insert(output_sizes_s, read_sizes(e.output_dims_));
      dict.insert(state_s, e.state_);
      dict.insert(retired_s, e.retired_);
      dict.insert(input_sizes_key, read_sizes(e.input_dims_));
      dict.insert(output_sizes_key, read_sizes(e.output_dims_));
      if (e.time_discovered_completed_.has_value()) {
        dict.insert(state_key, "completed");
      } else if (e.time_discovered_started_.has_value()) {
        dict.insert(state_key, "started");
      } else {
        dict.insert(state_key, "scheduled");
      }

      dict.insert(
          time_discovered_started_key,
          e.time_discovered_started_.has_value()
              ? int64_t(*e.time_discovered_started_)
              : c10::IValue());
      dict.insert(
          time_discovered_completed_key,
          e.time_discovered_completed_.has_value()
              ? int64_t(*e.time_discovered_completed_)
              : c10::IValue());
      dict.insert(retired_key, e.retired_);

      auto frames = new_list();
      for (int64_t frame : tb) {
        frames.push_back(all_frames.at(frame));
      }
      dict.insert(frames_s, frames);
      dict.insert(frames_key, frames);
      entries.push_back(dict);
    }
    return pickle_str(entries);

    // convert ncclDumpMap into a dictionary
    auto per_comm_dict = new_dict();
    if (ncclDumpMap.has_value()) {
      for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
        auto inner_dict = new_dict();
        for (const auto& [key, value] : ncclDump) {
          inner_dict.insert(key, value);
        }
        per_comm_dict.insert(ncclId, inner_dict);
      }
    }

    auto dict = new_dict();
    dict.insert(entries_key, entries);
    dict.insert(version_key, version_val);
    if (per_comm_dict.size() > 0) {
      dict.insert(nccl_comm_key, per_comm_dict);
    }

    return pickle_str(dict);
  }
};

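Since dump() serializes the result with pickle_str, the blob it produces should be readable from Python with the standard pickle module. A hedged sketch of inspecting such a dump, assuming the bytes have already been obtained (for example from a file written by a registered DebugInfoWriter; the path below is made up), with key names taken from the hunk above:

    import pickle

    # Illustrative path; the real location depends on how the dump was written.
    with open("nccl_trace_dump.bin", "rb") as f:
        blob = f.read()

    trace = pickle.loads(blob)
    print(trace["version"])  # e.g. "1.1"
    for entry in trace["entries"]:
        print(entry["pg_id"], entry["seq_id"], entry["profiling_name"],
              entry["state"], entry.get("duration_ms"))
    # Per-communicator NCCL state, present only when it was collected:
    for comm_id, comm_state in trace.get("nccl_comm_state", {}).items():
        print(comm_id, comm_state)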
@ -50,6 +50,35 @@

namespace {

#ifdef USE_C10D_NCCL

bool acquire_gil() {
  // Basically, if this function can acquire the GIL it will return quickly;
  // if not, it will hang forever. The idea is to call this from a thread
  // wrapped in a future, and then check the future after a timeout, to
  // determine whether we're facing GIL contention.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    return true;
  }

  // If we end up here, it's probably still a "pass" from the perspective of
  // checking whether Python is stuck, but currently we don't check the return
  // value of this function anyway; we just check whether it returned quickly
  // vs. timing out. Taking a long time is the main sign of trouble. A fast
  // return with either true or false is fine for the purpose of debugging
  // Python hangs.
  return false;
}

bool registerGilChecker() {
  c10d::get_gil_checker() = &acquire_gil;
  return true;
}

static bool registered = registerGilChecker();
#endif // USE_C10D_NCCL

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
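The comments above describe how acquire_gil is meant to be used: run it on a helper thread behind a future, and only check whether the future completes within a timeout. A minimal Python sketch of that future-plus-timeout pattern; gil_checker and TIMEOUT_S are illustrative names, not part of the diff:

    from concurrent.futures import ThreadPoolExecutor, TimeoutError

    TIMEOUT_S = 5  # illustrative timeout

    def gil_checker():
        # Stand-in for acquire_gil(): returns quickly unless the interpreter
        # is wedged (e.g. another thread holds the GIL and never releases it).
        return True

    def gil_looks_healthy():
        pool = ThreadPoolExecutor(max_workers=1)
        future = pool.submit(gil_checker)
        try:
            future.result(timeout=TIMEOUT_S)
            return True   # checker returned promptly
        except TimeoutError:
            return False  # checker did not return in time: likely a GIL hang
        finally:
            # Do not wait for a potentially stuck worker thread.
            pool.shutdown(wait=False)

    print(gil_looks_healthy())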
@ -1033,6 +1062,29 @@ Example::
    >>> store.add("first_key", 6)
    >>> # Should return 7
    >>> store.get("first_key")
)")
          .def(
              "check",
              &::c10d::Store::check,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Calling :meth:`~torch.distributed.store.check` with a list of ``keys`` queries
whether each of those keys has a value stored in the store. The call normally
returns immediately, but it can still deadlock in some edge cases, e.g.,
calling check after the TCPStore has been destroyed.

Arguments:
    keys (list[str]): The keys to query for presence in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> # Should return True
    >>> store.check(["first_key"])
)")
          .def(
              "delete_key",
@ -1404,7 +1456,11 @@ Arguments:
      .def_property_readonly(
          "underlying_store",
          &::c10d::PrefixStore::getUnderlyingStore,
          R"(Gets the underlying store object that PrefixStore wraps around.)");
          R"(Gets the underlying store object that PrefixStore wraps around.)")
      .def_property_readonly(
          "_underlying_non_prefix_store",
          &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
          R"(Recursively unwraps any layers of PrefixStore and gets the innermost non-prefix store.)");

  auto processGroup =
      py::class_<

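The new _underlying_non_prefix_store property differs from underlying_store in that it unwraps every PrefixStore layer rather than just one. A hedged usage sketch, assuming a build that includes the binding added above (HashStore is used here only as a convenient in-process store):

    import torch.distributed as dist

    base = dist.HashStore()  # any concrete store works here
    wrapped = dist.PrefixStore("outer", dist.PrefixStore("inner", base))

    one_layer = wrapped.underlying_store              # expected: the "inner" PrefixStore
    innermost = wrapped._underlying_non_prefix_store  # expected: the HashStore itself
    print(type(one_layer), type(innermost))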
@ -28,7 +28,7 @@ def tail_logfile(
            return
        time.sleep(interval_sec)

    with open(file) as fp:
    with open(file, errors="replace") as fp:
        while True:
            line = fp.readline()

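The one-line change above makes the log tailer robust to undecodable bytes in a worker's output: with errors="replace" the reader substitutes the U+FFFD replacement character instead of raising UnicodeDecodeError mid-tail. A small self-contained illustration (the file path and contents are made up, and the encoding is pinned only to keep the sketch deterministic):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "worker.log")
    with open(path, "wb") as f:
        f.write(b"ok line\n\xff\xfe not valid utf-8\n")

    # Without errors="replace" the second line would raise UnicodeDecodeError;
    # with it, the bad bytes decode to U+FFFD and tailing continues.
    with open(path, encoding="utf-8", errors="replace") as fp:
        for line in fp:
            print(repr(line))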