Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)

Compare commits

39 commits: mlazos/use...whc_flight
| SHA1 |
|---|
| 7c507b78c4 |
| 0019901601 |
| 18be18535b |
| 2729367313 |
| 33537aae24 |
| dcdb1337dd |
| 9cf0f2bd59 |
| 1d2e877c05 |
| f27b979b0c |
| f30d6047ad |
| 75311510ef |
| dbd6094d05 |
| 397b9d47e9 |
| 36a01a8ab9 |
| ee336cf58a |
| a9e2e745d7 |
| ab4df89eea |
| 9d02ebe876 |
| b61e01cce9 |
| f7ce61ba53 |
| e7bae15ab1 |
| e71b422908 |
| 389940ce60 |
| b2237a7c85 |
| ef5dfe3f3e |
| e303dc3c08 |
| 265efad2de |
| 60f0455905 |
| 4898313791 |
| f4da9adf6b |
| 8f7f35273e |
| 44ec9612ed |
| 4d3bea2b29 |
| 0bcdddc3c1 |
| 28b6220312 |
| 210b7b65e2 |
| 4da10b5cd3 |
| f09763814f |
| 80923ed5a6 |
@@ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) {
// Extend the nested class outside the parent class
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  TestDebugInfoWriter() : DebugInfoWriter(0) {}
  TestDebugInfoWriter(std::string namePrefix)
      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
@@ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  // The storer here is very similar to the fallback storer.
  // The only difference is that we are storing traces also in memory for
  // validation.
  std::string fileNamePrefix = c10d::getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
      std::make_unique<TestDebugInfoWriter>();
      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();
  pg.registerDebugInfoWriter(std::move(wrterForTestPtr));
  c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
@@ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
    // TODO (kwen2501)
    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);

@@ -3542,11 +3542,12 @@ class SparseCollective(MultiProcessTestCase):
class NCCLTraceTestBase(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'
        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'  # see 'timing_enabled' parametrized tests
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
        self.tempdir = tempfile.TemporaryDirectory()
        os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
        os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
        self._spawn_processes()

    @classmethod
@@ -3617,17 +3618,27 @@ class NCCLTraceTest(NCCLTraceTestBase):

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
    def test_short(self):
    @parametrize("timing_enabled", [True, False])
    def test_short(self, timing_enabled):
        if self.rank == self.MAIN_PROCESS_RANK:
            return
        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        a = torch.full((3, 4), float(self.rank), device=device)
        for i in range(2):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)

        # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api
        time.sleep(1)

        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        ver = t['version']
        self.assertEqual(ver, "1.0")
        t = t['entries']
        self.assertEqual(len(t), 2)
        last = t[-1]
        self.assertEqual(last['state'], 'completed')
@@ -3636,9 +3647,14 @@ class NCCLTraceTest(NCCLTraceTestBase):
        self.assertEqual(last['output_sizes'], ((3, 4),))
        self.assertEqual(last['seq_id'], 2)
        now = datetime.now()
        event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000)
        event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000)
        before_test = now - timedelta(minutes=1)
        self.assertTrue(before_test < event_created_time < now)
        if timing_enabled:
            # very loose bounds, measured 0.036 ms on devgpu
            self.assertTrue(0 < last['duration_ms'] < 100)
        else:
            self.assertTrue("duration_ms" not in last)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -3698,6 +3714,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
        f.wait()
        torch.cuda.synchronize(device=device)
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        t = t['entries']
        self.assertEqual(len(t), 10)
        first = t[0]
        last = t[-1]
@@ -3732,6 +3749,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
                pg.allreduce(a).wait()
            e.synchronize()
            t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
            t = t['entries']
            if self.rank == 0:
                self.assertEqual(t[-1]['seq_id'], 1)
                self.assertEqual(t[-1]['state'], 'completed')
@@ -3773,6 +3791,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
                # give the other thread some time to fill the cuda buffer
                time.sleep(5)
                t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
                t = t['entries']
                if self.rank == 0:
                    self.assertEqual(t[-1]['seq_id'], 1)
                    self.assertEqual(t[-1]['state'], 'completed')
@@ -3832,6 +3851,9 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_timeout_dumps(self, timing_enabled):
        # We need to completely disable the coordinated timeout dump to avoid rank 0
        # also timeout so that we set the check frequency to be very large (25 min).
        os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1500000'

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for rank0 to crash before looking for its output file
@@ -3839,6 +3861,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
            self.assertEqual(self._wait_process(0, timeout=90), -6)
            with open(self._trace_name(rank=0), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 2)
                self.assertEqual(t[0]['seq_id'], 1)
                self.assertEqual(t[0]['state'], 'completed')
@@ -3868,7 +3891,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
instantiate_parametrized_tests(NCCLTraceTest)

class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
    def _check_return_codes(self, elapsed_time):
        # the base test infra assumes processes exit with matching return codes,
        # but we want rank0 to abort and rank1 to exit cleanly in this test
@@ -3877,7 +3900,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
    def test_timeout_dumps_on_idle_ranks(self):
    def test_timeout_dumps_on_stuck_ranks(self):

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for both rank0 and 1 to crash before looking for both ranks' output
@@ -3888,20 +3911,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
            with open(self._trace_name(rank=0), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 2)
            with open(self._trace_name(rank=1), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 1)
                self.assertEqual(t[0]['seq_id'], 1)
                self.assertEqual(t[0]['state'], 'completed')
            return

        # Set heartbeat timeout to a shorter one (default timeout is 2 min).
        os.environ[
            "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"
        ] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}"
        pg = self._create_process_group_nccl()

        device = self.local_device
        with torch.cuda.device(device):
            a = torch.full((3, 4), float(self.rank), device=device)
@@ -3910,12 +3930,14 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
            if self.rank == 0:
                pg.allreduce(a).wait()

            # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
            # rank 0 will get stuck, timeout and then signal a timeout to all ranks.
            torch.cuda.synchronize()

            # Force rank 1 to idle so that it also gets debug info dump triggered.
            if self.rank == 1:
                time.sleep(6)
                # Force rank 1 to idle so that it will eventually timeout as well after
                # getting the global signal to dump the debugging info.
                time.sleep(600)


if __name__ == "__main__":
    assert (

@@ -71,7 +71,7 @@ class StoreTestBase:
    def _create_store(self, i):
        raise RuntimeError("not implemented")

    def _test_set_get(self, fs):
    def _test_set_get_check(self, fs):
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
@@ -90,14 +90,16 @@ class StoreTestBase:
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get(self):
        self._test_set_get(self._create_store())
    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
@@ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        store = self._create_store()
        wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)

class MyPythonStore(dist.Store):
    def __init__(self):
        super().__init__()

@@ -175,14 +175,29 @@ std::string getNcclErrorDetailStr(
    c10::optional<std::string> processGroupFailureReason = c10::nullopt);

// Write NCCL debug info to local disk or any storage users define.
// There are some constrains we set for the debug info writer:
// 1. The writer should only be registered once.
// 2. Once registered, users cannot change it including un-register.
// 3. It is recommended to register the customized writer in the trainer setup,
//    If users don't register before calling launchAsyncDebugDump, then users
//    lose the chance to register (and the default writer will be
//    auto-registered).
class TORCH_API DebugInfoWriter {
 public:
  DebugInfoWriter(int rank);
  virtual ~DebugInfoWriter();
  virtual void write(const std::string& ncclTrace);
  static DebugInfoWriter& getWriter(int rank);
  static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);

 protected:
  DebugInfoWriter(std::string namePrefix, int rank) {
    filename_ = c10::str(namePrefix, rank);
  }
  std::string filename_;

 private:
  static std::unique_ptr<DebugInfoWriter> writer_;
  static std::atomic<bool> hasWriterRegistered_;
};

// RAII wrapper for NCCL communicator

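For context, the registration flow described by the comment block above looks roughly like the following sketch; `MyTraceUploader`, `setupDebugWriter`, and the path prefix are illustrative and not part of this change, and the include path is only a hint at where the declaration lives in-tree.

```cpp
// Sketch of a user-defined writer against the DebugInfoWriter interface shown
// above (assumes the c10d header declaring DebugInfoWriter is on the include
// path, e.g. torch/csrc/distributed/c10d/NCCLUtils.hpp in-tree).
#include <fstream>
#include <memory>
#include <string>

class MyTraceUploader : public c10d::DebugInfoWriter {
 public:
  explicit MyTraceUploader(int rank)
      // The protected base constructor sets filename_ = namePrefix + rank.
      : c10d::DebugInfoWriter("/my/storage/nccl_trace_rank_", rank) {}

  void write(const std::string& ncclTrace) override {
    // Replace this with your own storage backend; the auto-registered default
    // writer also just writes the serialized flight-recorder blob to a file.
    std::ofstream out(filename_, std::ios::binary);
    out.write(ncclTrace.data(), ncclTrace.size());
  }
};

// Register once, early in trainer setup. Per the comments above, a second
// registration is rejected, and if nothing is registered before the first
// dump, the default file-based writer is auto-registered.
void setupDebugWriter(int globalRank) {
  c10d::DebugInfoWriter::registerWriter(
      std::make_unique<MyTraceUploader>(globalRank));
}
```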
@@ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
  return store_;
}

c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
  c10::intrusive_ptr<Store> store = store_;

  while (store) {
    // Attempt to dynamically cast to PrefixStore
    PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
    if (asPrefixStore) {
      store = asPrefixStore->getUnderlyingStore();
    } else {
      break; // We've reached a non-PrefixStore
    }
  }

  TORCH_CHECK(
      store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
  return store;
}

} // namespace c10d

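As a usage sketch (not part of this change): a caller that holds a possibly prefix-wrapped store can recover the shared base store as below. The helper name is illustrative; the same pattern appears in the ProcessGroupNCCL constructor later in this diff, where the unwrapped store becomes `globalStore_` for cross-process-group coordination.

```cpp
// Sketch only; assumes the c10d headers declaring Store and PrefixStore are
// included.
c10::intrusive_ptr<c10d::Store> resolveGlobalStore(
    const c10::intrusive_ptr<c10d::Store>& store) {
  // Each ProcessGroup usually sees the base store wrapped in its own
  // PrefixStore; peeling the prefixes yields the one store shared by all PGs.
  auto* prefixStore = dynamic_cast<c10d::PrefixStore*>(store.get());
  return prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store;
}
```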
@@ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store {

  c10::intrusive_ptr<Store> getUnderlyingStore();

  // Recursively to fetch the store before layers of wrapping with PrefixStore.
  c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();

 protected:
  std::string prefix_;
  c10::intrusive_ptr<Store> store_;

@@ -341,6 +341,35 @@ std::string dump_nccl_trace() {
  return NCCLTraceBuffer::get()->dump();
}

c10::optional<std::function<std::string()>>& get_cpp_trace_dumper() {
  static c10::optional<std::function<std::string()>> dumper(c10::nullopt);
  return dumper;
}

gil_checker_t& get_gil_checker() {
  static gil_checker_t gil_checker = nullptr;
  return gil_checker;
}

std::future<bool> launchAsyncGilCheck() {
  std::promise<bool> resultPromise;
  std::future<bool> resultFuture = resultPromise.get_future();
  TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker");
  std::thread workerThread([promise = std::move(resultPromise)]() mutable {
    try {
      auto& gil_checker = get_gil_checker();
      promise.set_value((*gil_checker)());
    } catch (...) {
      promise.set_exception(std::current_exception());
    }
  });

  // Detach the thread to allow it to run independently
  workerThread.detach();

  return resultFuture;
}

// Return CUDA device with ordinal given by input rank.  If we aren't
// bound to a specific device, there is no strict guarantee that this
// heuristic is the correct assignment of ranks to GPUs that Python
@@ -358,7 +387,7 @@ at::Device ProcessGroupNCCL::guessDeviceForRank() const {
  }
}

const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;

@@ -466,12 +495,17 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
  std::unique_lock<std::mutex> lock(mutex_);
  exception_ = exception_ptr;
  if (exception_) {
    LOG(INFO) << "[Rank " << rank_ << "]"
              << " found async exception when checking for NCCL errors: "
    LOG(INFO) << logPrefix()
              << "found async exception when checking for NCCL errors: "
              << getExceptionMsgFromExceptionPtr(exception_);
  }
}

const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const {
  static std::string prefix = c10::str("[Rank ", rank_, "] ");
  return prefix;
}

void ProcessGroupNCCL::WorkNCCL::setException(
    std::exception_ptr exception_ptr) {
  std::unique_lock<std::mutex> lock(mutex_);
@@ -527,9 +561,7 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
    return true;

  std::string exceptionMsg = c10::str(
      "[Rank ",
      rank_,
      "] ",
      logPrefix(),
      "Watchdog caught collective operation timeout: ",
      *this,
      " ran for ",
@@ -550,13 +582,13 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
        "Some NCCL operations have failed or timed out. Due to the ",
        "asynchronous nature of CUDA kernels, subsequent GPU operations ",
        "might run on corrupted/incomplete data.");
    LOG(ERROR) << exceptionMsg;
    LOG(ERROR) << logPrefix() << exceptionMsg;
    C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException");

    if (SHOULD_TEAR_DOWN(errorHandling)) {
      auto tearDownMsg = c10::str(
          "To avoid data inconsistency, we are taking the entire process down.");
      LOG(ERROR) << tearDownMsg;
      LOG(ERROR) << logPrefix() << tearDownMsg;
      std::rethrow_exception(exception_);
    }
  }
@@ -597,9 +629,8 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
      // can not run new events successfully.
      if (timedOut) {
        std::string exceptionMsg = c10::str(
            "[Rank ",
            rank_,
            "] Work ",
            logPrefix(),
            "Work ",
            (*this),
            " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1).");
        LOG(ERROR) << exceptionMsg;
@@ -713,6 +744,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
      ValueError,
      at::cuda::getNumGPUs() != 0,
      "ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
  logPrefix_ = createLogPrefix();
  blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
  asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
      getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
@@ -723,8 +755,18 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  heartbeat_ = 1ULL;
  monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
  heartbeatTimeoutInSec_ =
      getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 2 /*2 Mins*/);
      getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/);
  waitTimeoutDumpInMilSec_ =
      getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 2000);
  coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
  ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
  // store_ usually is wrapped with PrefixStore and the prefix is different
  // across different ProcessGroupNCCL(PG) instances. We need to get the
  // underlying non-PrefixStore for sharing global information shared across
  // different PGs.
  PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
  globalStore_ =
      prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
#ifdef ENABLE_NCCL_ERROR_CHECKING
  enableTiming_.store(
      getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@@ -737,15 +779,15 @@ ProcessGroupNCCL::ProcessGroupNCCL(
          expandable_segments()) {
    useTensorRegisterAllocatorHook_ = false;
    LOG(INFO)
        << "[Rank " << rank_
        << "] disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
        << logPrefix()
        << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
  }
#endif

  if (blockingWait_) {
    if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
      LOG(INFO)
          << "[Rank " << rank_ << "] TORCH_NCCL_BLOCKING_WAIT and "
          << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and "
          << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG"
          << "should not both be enabled. "
          << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process.";
@@ -755,8 +797,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  } else {
    if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
      LOG(INFO)
          << "[Rank " << rank_
          << "] TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
          << logPrefix()
          << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
          << "must both be enabled. "
          << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING.";
      asyncErrorHandling_ = SkipCleanUp;
@@ -781,12 +823,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  const std::string OFF = "OFF";
  std::string torch_distributed_debug =
      getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
  std::string nccl_debug = getCvarString({"NCCL_DEBUG"}, OFF.c_str());
  LOG(INFO) << "[Rank " << rank_
            << "] ProcessGroupNCCL initialization options: "
            << "NCCL version: " << getNcclVersion()
  LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
            << "NCCL version: " << getNcclVersion() << ", size: " << size
            << ", global rank: " << globalRank()
            << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
            << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeout_
            << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
            << waitTimeoutDumpInMilSec_
            << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
            << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
            << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
@@ -804,7 +847,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
            << monitorThreadEnabled_.load()
            << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
            << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
            << ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID();
            << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
            << ", ID=" << this->getID();

  RECORD_PARAM_COMMS(
      0, // seq
@@ -835,7 +879,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
  std::vector<at::Device> rankDevices = {device};
  const auto key = getKeyFromDevices(rankDevices);
  LOG(INFO) << "Eagerly connecting nccl backend with device " << device;
  LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device "
            << device;
  getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
}

@@ -846,8 +891,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
#ifdef NCCL_HAS_COMM_SPLIT
  std::vector<at::Device> rankDevices = {device};
  const auto key = getKeyFromDevices(rankDevices);
  LOG(INFO) << "Performing nocolor split on backend device " << device
            << ", key " << key << ", i am " << this;
  LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
            << device << ", key " << key << ", i am " << this;
  auto comm = getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
  TORCH_CHECK_WITH(
      DistBackendError,
@@ -897,8 +942,7 @@ void ProcessGroupNCCL::runHealthCheck() {
  // We don't need to join the thread, just need to verify health check via the
  // CV. Hence we detach the thread here.
  t.detach(); // NOLINT
  LOG(INFO) << "[Rank " << rank_ << "]"
            << " will wait up to " << options_->timeout.count()
  LOG(INFO) << logPrefix() << "will wait up to " << options_->timeout.count()
            << " msec for NCCL health check to complete.";
  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
  healthCheckData.healthCheckCv.wait_for(
@@ -1002,10 +1046,49 @@ std::future<bool> ProcessGroupNCCL::launchAsyncDebugDump() {
  return resultFuture;
}

void abortCommsFromMap(
std::chrono::time_point<std::chrono::steady_clock> getWakeupTime(
    int intervalInMilSec) {
  return std::chrono::steady_clock::now() +
      std::chrono::milliseconds(intervalInMilSec);
}

void ProcessGroupNCCL::waitForDumpOrTimeout(
    std::future<bool>& fut,
    const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
    size_t timeout_sec) {
  TORCH_CHECK(fut.valid(), "Expected a valid future");

  auto futStatus = fut.wait_for(std::chrono::seconds(timeout_sec));
  TORCH_CHECK(
      futStatus != std::future_status::deferred, "Expected eager launch.");
  if (futStatus == std::future_status::ready) {
    // Calling .get() will re-raise any exception from the future, and we don't
    // care about the retval
    try {
      fut.get();
      std::this_thread::sleep_until(wakeUpTime);
    } catch (const std::exception& e) {
      LOG(ERROR) << logPrefix()
                 << "Caught exception during async debug dump: \"" << e.what()
                 << "\"\n";
    } catch (...) {
      LOG(ERROR) << logPrefix()
                 << "Caught unknown exception during async debug dump.";
    }
  } else {
    LOG(INFO)
        << logPrefix() << "Debug dump timed out and is being abandoned."
        << " This may be due to slow ADDR2LINE performance processing stacktraces."
        << " Try TORCH_DISABLE_ADDR2LINE=1 and TORCH_NCCL_TRACE_CPP_STACK=0 to work around.";
  }
  // Ensure we sleep at least until wakeUpTime regardless of future execution
  // time
  std::this_thread::sleep_until(wakeUpTime);
}

void ProcessGroupNCCL::abortCommsFromMap(
    std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
        ncclCommsMap,
    const int rank,
    c10::optional<std::string> abortReason) {
  // The process may control multiple devices, loop through the communicators on
  // each device
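The two helpers above implement a wait-with-deadline pattern: wait on the dump future for a bounded time, and then, whatever happened, sleep until a wake-up time that was fixed when the dump was launched, so the process is not torn down before the dump can land. A self-contained sketch of the same pattern, with illustrative names and durations:

```cpp
// Standalone sketch of the getWakeupTime/waitForDumpOrTimeout pattern above.
#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  using clock = std::chrono::steady_clock;
  // Fix the wake-up time first, so the grace period is measured from launch,
  // not from whenever the future happens to finish.
  auto wakeUpTime = clock::now() + std::chrono::milliseconds(2000);

  // Stand-in for launchAsyncDebugDump(): some asynchronous best-effort work.
  std::future<bool> dump = std::async(std::launch::async, [] {
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    return true;
  });

  // Bounded wait: either the dump finished or it is abandoned.
  if (dump.wait_for(std::chrono::seconds(30)) == std::future_status::ready) {
    std::cout << "dump finished: " << dump.get() << "\n";
  } else {
    std::cout << "dump timed out and is being abandoned\n";
  }

  // Regardless of how the wait ended, hold the process until the deadline so
  // the dump has a chance to reach storage before teardown continues.
  std::this_thread::sleep_until(wakeUpTime);
}
```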
@@ -1026,8 +1109,17 @@ void abortCommsFromMap(
    // their responsibility to destroy the process group and recreate
    // it to recover from errors.

    LOG(INFO) << "[Rank " << rank << "] Destroyed " << ncclComms.size()
              << "communicators on CUDA device " << devName;
    c10::StreamId streamId = -1;
    if (ncclStreams_.find(devName) != ncclStreams_.end()) {
      auto streams = ncclStreams_.at(devName);
      if (streams.size() > 0) {
        streamId = streams[0].id();
      }
    }

    LOG(INFO) << logPrefix() << "] Destroyed " << ncclComms.size()
              << "communicators on CUDA device: " << devName
              << " with stream: " << streamId;
  }
}

@@ -1048,8 +1140,8 @@ void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
  ncclCommDevIdxMapMutex.unlock();

  std::lock_guard<std::mutex> lock(mutex_);
  abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
  abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
  abortCommsFromMap(devNCCLCommMap_, abortReason);
  abortCommsFromMap(inInitializationCommMap_, abortReason);
}

void ProcessGroupNCCL::shutdown() {
@@ -1067,6 +1159,7 @@ void ProcessGroupNCCL::shutdown() {
}

ProcessGroupNCCL::~ProcessGroupNCCL() {
  LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
  terminateProcessGroup_.store(true);
  workMetaListCV_.notify_one();

@@ -1074,6 +1167,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
  if (ncclCommWatchdogThread_.joinable()) {
    ncclCommWatchdogThread_.join();
  }
  LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
#endif

  if (onCompletionHookThread_.joinable())
@@ -1083,6 +1177,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
  // threads dying due to aborted communicator and raising a SIGABRT
  std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
  abort(abortReason);
  LOG(INFO) << logPrefix() << "ProcessGroupNCCL abort finished.";

  // We need to wait for abort to finish before we can safely shut down
  // heartbeat monitoring thread.
@@ -1095,33 +1190,20 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
#endif
}

void ProcessGroupNCCL::registerDebugInfoWriter(
    std::unique_ptr<DebugInfoWriter> writer) {
  TORCH_CHECK_WITH(
      DistBackendError,
      debugInfoWriter_ == nullptr,
      "ProcessGroupNCCL debugInfoWriter already registered");
  debugInfoWriter_ = std::move(writer);
}

bool ProcessGroupNCCL::dumpDebuggingInfo() {
  // Serialize all calls to this function to avoid corrupting data, but allow
  // multiple calls in one runtime. User is responsible for preserving the
  // output file from an earlier call before a later call overwrites it.
  std::lock_guard<std::mutex> lock(writeDebugInfoMutex_);
  LOG(ERROR) << "ProcessGroupNCCL preparing to dump debug info.";
  static std::mutex writeDebugInfoMutex;
  std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
  LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info.";
  if (ncclTraceBufferSize_ > 0) {
    // We dump nccl trace into local disk by default and users can register
    // their customized writer by inheriting `DebugInfoWriter` via
    // `registerDebugInfoWriter`.
    auto ncclTrace = dump_nccl_trace();
    if (debugInfoWriter_ == nullptr) {
      // Dump the trace blob into local disk as a fallback.
      std::unique_ptr<DebugInfoWriter> debugInfoWriterPtr =
          std::make_unique<DebugInfoWriter>(rank_);
      registerDebugInfoWriter(std::move(debugInfoWriterPtr));
    }
    debugInfoWriter_->write(ncclTrace);
    DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank());
    writer.write(ncclTrace);
    return true;
  }
  return false;
@@ -1130,48 +1212,133 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() {
void ProcessGroupNCCL::terminateProcess(std::string errMsg) {
  // Logging with `FATAL`, after errMsg printed, it calls `std::abort()`
  // to terminate the program execution.
  LOG(FATAL) << errMsg;
  LOG(FATAL) << logPrefix() << errMsg;
}

int computeDeltaMS(
    std::chrono::time_point<std::chrono::steady_clock> start,
    std::chrono::time_point<std::chrono::steady_clock> end) {
  return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
      .count();
}

void ProcessGroupNCCL::heartbeatMonitor() {
  uint64_t heartBeatCounter = 0ULL;
  std::string errorMsg;
  std::string exitMsg;
  bool checkTimeoutSignal = (dumpOnTimeout_ && uid_ == 0);
  int monitorPollInterval = checkTimeoutSignal ? coordCheckIntervalMilSec_
                                               : heartbeatTimeoutInSec_ * 1000;
  auto lastTimePollStore = std::chrono::steady_clock::now();
  auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
  std::future<bool> asyncDebugDump;
  while (true) {
    // This won't have any lock since this lock is only used here.
    // Please be aware that mutex `monitorMutex_` should not be used
    // somewhere else to avoid the deadlock.
    std::unique_lock<std::mutex> lock(monitorMutex_);
    if (monitorWakeUpCV_.wait_for(
            lock, std::chrono::seconds(heartbeatTimeoutInSec_), [&] {
            lock, std::chrono::milliseconds(monitorPollInterval), [&] {
              return terminateHeartbeatMonitorThread_.load();
            })) {
      // For the normal complete or user interception, monitorWakeUpCV_
      // will get notified, we early return and exit heartbeatMonitor.
      return;
    }
    auto currentTime = std::chrono::steady_clock::now();

    // Check the heart beat of watchdog thread.
    auto heartbeat = heartbeat_;
    if (heartbeat != heartBeatCounter) {
      heartBeatCounter = heartbeat;
    } else {
      // No heartbeat increase detected and timeout.
      break;
    // We put extra functionality in the thread for the default PG (aka, uid_=0)
    // because the signal is same across different PGs. We only need to run
    // once per process to avoid duplicate things performed in too many separate
    // threads. For example, we check a global flag on the TCPStore periodically
    // to see if any PG on any rank observed a timeout and signaled peers to
    // dump debugging info, and we avoid hammering the TCPStore from all PGs on
    // the same rank.
    if (checkTimeoutSignal) {
      // We poll store to see if some ranks have flagged a timeout when
      // we haven't polled for `heartbeat_timeout` seconds and there haven't
      // any work added or removed for `watchdog_timeout` seconds.
      if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >=
              kWatchdogThreadSleepMillis &&
          computeDeltaMS(lastTimePollStore, currentTime) >=
              coordCheckIntervalMilSec_) {
        lastTimePollStore = currentTime;
        if (globalStore_->check({std::string(TIMEOUT_DUMP)})) {
          errorMsg = c10::str(
              logPrefix(),
              "Received a global timeout from another rank and will ",
              "start to dump the debug info.");
          exitMsg = c10::str(
              "ProcessGroupNCCL's watchdog detected a collective timeout and notified current rank. ",
              "This is most likely caused by incorrect usages of collectives, e.g., wrong ",
              "sizes used across ranks, the order of collectives is not same for all ranks ",
              "or the scheduled collective, for some reason, didn't run. Additionally, ",
              "this can be caused by GIL deadlock or other reasons such as network errors or ",
              "bugs in the communications library (e.g. NCCL), etc. We tried our best to ",
              "dump the debug info into the storage to help you debug the issue.");
          break;
        }
      }
    }

    if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
        heartbeatTimeoutInSec_ * 1000) {
      // Check the heart beat of watchdog thread.
      lastTimeHeartBeatCheck = currentTime;
      auto heartbeat = heartbeat_.load();
      if (heartbeat != heartBeatCounter) {
        heartBeatCounter = heartbeat;
      } else {
        // No heartbeat increase detected and timeout.
        errorMsg = c10::str(
            logPrefix(),
            "Heartbeat monitor timed out! Process will be terminated after dumping debug info.",
            " workMetaList_.size()=",
            workMetaList_.size());
        exitMsg = c10::str(
            "ProcessGroupNCCL's watchdog got stuck for ",
            heartbeatTimeoutInSec_,
            " seconds without making progress in monitoring enqueued collectives. ",
            "This typically indicates a NCCL/CUDA API hang blocking the watchdog, ",
            "and could be triggered by another thread holding the GIL inside a ",
            "CUDA api, or other deadlock-prone behaviors.",
            "If you suspect the watchdog is not actually stuck and a longer timeout would help, ",
            "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value "
            "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)."
            "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout "
            "or false positive abort; otherwise, please attempt to debug the hang.");
        break;
      }
    }
  }
  LOG(ERROR) << errorMsg;

  auto& cpp_dumper = get_cpp_trace_dumper();
  if (cpp_dumper.has_value()) {
    LOG(INFO) << "Dumping c++ stacktraces: " << cpp_dumper.value()();
  }

  auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
  // Store debug info to storage if no other thread does it. (By default to
  // local disk)
  std::future<bool> asyncDebugDump = launchAsyncDebugDump();
  asyncDebugDump = launchAsyncDebugDump();

  // Create a error message reported from MonitorThread, so
  // we throw exception and make the whole process to be killed.
  const auto exitMsg = c10::str(
      "[Rank ",
      rank_,
      "] NCCL monitor thread timeout. Basically, this could ",
      "be due to CUDA or NCCL calls being unexpectedly blocking, ",
      "especially when your program enters a deadlock state in watchdog "
      "or destructors. If you see this error, please file a bug to PyTorch.");
  if (get_gil_checker() != nullptr) {
    auto fut = launchAsyncGilCheck();
    auto kGilCheckTimeout = std::chrono::milliseconds(300);
    auto futStatus = fut.wait_for(kGilCheckTimeout);
    if (futStatus != std::future_status::ready) {
      TORCH_CHECK(
          futStatus != std::future_status::deferred,
          "Expected the future to have been launched eagerly.");
      LOG(ERROR)
          << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
    }
    LOG(INFO) << "Could acquire GIL on exit";
  } else {
    LOG(INFO)
        << "GIL checker was not registered, perhaps this is a no-python build?";
  }

  // There are two possible cases for the watchdog thread exit:
  // Case one: desync report runs quickly, and it follows the step:
@@ -1197,46 +1364,40 @@ void ProcessGroupNCCL::heartbeatMonitor() {
  // We already log completion inside the thread, so it may not be necessary to
  // check the return value here.  We mainly use a future so we can exit early
  // if done.
  asyncDebugDump.wait_for(std::chrono::seconds(heartbeatTimeoutInSec_));
  waitForDumpOrTimeout(asyncDebugDump, wakeUpTime);

  if (!terminateHeartbeatMonitorThread_.load()) {
    const auto logMsg = c10::str(
        "[Rank ",
        rank_,
        "] monitoring thread detects no heartbeat and will finally kill the process!",
        " terminateProcessGroup_",
        terminateProcessGroup_,
        " collectiveDebugInfoMode_",
        collectiveDebugInfoMode_);
    LOG(ERROR) << logMsg;
    terminateProcess(exitMsg);
    // Create a error message reported from MonitorThread, so
    // we throw exception and make the whole process to be killed.
    // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
    // url here.
    const auto finalExitMsg = c10::str(logPrefix(), exitMsg);
    terminateProcess(finalExitMsg);
  }
}

void ProcessGroupNCCL::ncclCommWatchdog() {
  try {
    VLOG(2) << "[Rank " << rank_ << "] NCCL watchdog thread started!";
    VLOG(2) << logPrefix() << "NCCL watchdog thread started!";
    if (monitorThreadEnabled_.load()) {
      ncclHeartbeatMonitorThread_ =
          std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
    }
    watchdogHandler();
    VLOG(2) << "[Rank " << rank_
            << "] NCCL watchdog thread terminated normally";
    VLOG(2) << logPrefix() << "NCCL watchdog thread terminated normally";
  } catch (std::exception& e) {
    if (std::string(e.what()).find("driver shutting down") !=
        std::string::npos) {
      LOG(INFO)
          << "[Rank " << rank_
          << "] main process destroyed cuda before watchdog loop exited, terminating watchdog."
          << logPrefix()
          << "main process destroyed cuda before watchdog loop exited, terminating watchdog."
          << " (Watchdog caught exception: " << e.what();

    } else {
      // Append error message reported from watchdogHandler
      const auto exitMsg = c10::str(
          "[Rank ",
          rank_,
          "] NCCL watchdog thread terminated with exception: ",
          logPrefix(),
          "NCCL watchdog thread terminated with exception: ",
          e.what());
      LOG(ERROR) << exitMsg;
      // TODO(whc) clean up the rethrow - why is it stored in a class var and
@@ -1247,9 +1408,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
    }
  } catch (...) {
    const auto exitMsg = c10::str(
        "[Rank ",
        rank_,
        "] NCCL watchdog thread terminated with exception: unknown");
        logPrefix(), "NCCL watchdog thread terminated with exception: unknown");
    LOG(ERROR) << exitMsg;
    watchDogException_ =
        std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg));
@@ -1288,12 +1447,13 @@ std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {

#if defined(__linux__)
struct DumpPipe {
  DumpPipe(bool enabled, int rank) {
    if (!enabled) {
  DumpPipe(int rank) {
    std::string fileStem =
        getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
    if (fileStem.empty() ||
        getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
      return;
    }
    std::string fileStem = getCvarString(
        {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
    TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
    std::string filename = c10::str(fileStem, rank, ".pipe");
    TORCH_CHECK(
@@ -1305,6 +1465,8 @@ struct DumpPipe {
        "Error creating named pipe ",
        filename);
    fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
    LOG(INFO) << "Pipe file " << filename
              << " has been opened, write to it to trigger NCCL Debug Dump.";
    TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
  }
  bool shouldDump() {
@@ -1329,22 +1491,40 @@ struct DumpPipe {
};
#else
struct DumpPipe {
  DumpPipe(bool enabled, int rank) {}
  DumpPipe(int rank) {}
  bool shouldDump() {
    return false;
  }
};
#endif

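DumpPipe above relies on standard named-pipe behavior: create a FIFO, open it non-blocking, and treat any bytes written to it (for example, `echo 1 > /tmp/nccl_trace_rank_0.pipe`) as a request to dump. The body of `shouldDump()` is elided in this hunk, so the following is only a standalone sketch of that mechanism, with an illustrative path:

```cpp
// Standalone sketch of the named-pipe trigger mechanism used by DumpPipe.
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <string>

int main() {
  std::string filename = "/tmp/nccl_trace_rank_0.pipe";  // illustrative path
  unlink(filename.c_str());                  // start fresh
  if (mkfifo(filename.c_str(), 0666) != 0)   // create the named pipe
    return 1;
  int fd = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
  if (fd == -1)
    return 1;

  // Poll loop: a non-blocking read returns > 0 only after someone has written
  // to the pipe, which is the signal to trigger a debug dump.
  for (;;) {
    char buf[128];
    ssize_t n = read(fd, buf, sizeof(buf));
    if (n > 0) {
      std::puts("trigger received: dumping debug info");
      break;
    }
    sleep(1);  // in ProcessGroupNCCL this check rides on the watchdog loop
  }
  close(fd);
  return 0;
}
```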
std::string ProcessGroupNCCL::createLogPrefix() const {
  return c10::str("[PG ", uid_, " Rank ", rank_, "] ");
}

const std::string& ProcessGroupNCCL::logPrefix() const {
  return logPrefix_;
}

const int& ProcessGroupNCCL::globalRank() const {
  static int globalRank = rank_;
  return globalRank;
}

void ProcessGroupNCCL::watchdogHandler() {
  bool done = false;
  lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
  auto lastTimePollStore = std::chrono::steady_clock::now();
  c10::optional<std::future<bool>> optAsyncDebugDump;

  std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;

  DumpPipe dumpPipe(dumpOnTimeout_, rank_);
  c10::optional<DumpPipe> dumpPipe = c10::nullopt;
  if (uid_ == 0) {
    // DumpPipe is one per-trainer process, and its convenient to name them
    // after 'global' ranks in the system, So we assume processgroup (uid)==0 is
    // the global PG and has globally unique rank ids across trainers.
    dumpPipe.emplace(rank_);
  }
  while (!done || !terminateProcessGroup_.load()) {
    std::unique_lock<std::mutex> lock(workMetaListMutex_);
    // We busy-poll the work vector every kWatchdogThreadSleepMillis
@@ -1356,37 +1536,6 @@ void ProcessGroupNCCL::watchdogHandler() {
    // Bump up heart beat by one.
    heartbeat_++;

    // poll store to see if some ranks have flagged a timeout when
    // we haven't polled for `heartbeat_timeout` seconds and there haven't
    // any work added or removed for `watchdog_timeout` seconds.
    if (dumpOnTimeout_) {
      auto currentTime = std::chrono::steady_clock::now();
      auto timeSinceLastWorkListUpdate =
          std::chrono::duration_cast<std::chrono::milliseconds>(
              (currentTime - lastWorkListUpdateTime_))
              .count();
      auto timeSinceLastPollStore =
          std::chrono::duration_cast<std::chrono::milliseconds>(
              (currentTime - lastTimePollStore))
              .count();
      if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis &&
          timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) {
        lastTimePollStore = currentTime;
        if (store_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) {
          optAsyncDebugDump = launchAsyncDebugDump();
          optAsyncDebugDump->wait_for(
              std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30));
          const auto exitMsg = c10::str(
              "Some other rank's watchdog thread detected a timeout and notified ",
              "all other ranks, so we're dumping debug info and aborting [Rank ",
              rank_,
              "] as well.");
          LOG(ERROR) << exitMsg;
          C10_THROW_ERROR(DistBackendError, exitMsg);
        }
      }
    }

    for (auto it = workMetaList_.begin(); it != workMetaList_.end();
         /* no increment */) {
      auto& work = *it;
@@ -1411,9 +1560,10 @@ void ProcessGroupNCCL::watchdogHandler() {
              // abort process immediately.
              collectiveDebugInfoMode_.store(true);
              std::vector<uint8_t> vec(1);
              store_->set(std::string(TIMEOUT_DUMP), vec);
              globalStore_->set(std::string(TIMEOUT_DUMP), vec);
            }

            auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
            if (dumpOnTimeout_ && !optAsyncDebugDump) {
              // Store debug info to storage. (By default to local disk)
              optAsyncDebugDump = launchAsyncDebugDump();
@@ -1421,20 +1571,21 @@ void ProcessGroupNCCL::watchdogHandler() {

            if (desyncDebug_) {
              auto desyncMsg = getNCCLWatchdogDebugInfo();
              LOG(ERROR) << desyncMsg;
              LOG(ERROR) << logPrefix() << desyncMsg;
            }

            if (dumpOnTimeout_) {
              // Store debug info to storage. (By default to local disk)
              optAsyncDebugDump->wait_for(
                  std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30));
              waitForDumpOrTimeout(*optAsyncDebugDump, wakeUpTime);
            }

          } catch (const std::exception& e) {
            LOG(ERROR) << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
 | 
			
		||||
            LOG(ERROR) << logPrefix()
 | 
			
		||||
                       << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
 | 
			
		||||
                       << " Please file an issue. Error: " << e.what();
 | 
			
		||||
          } catch (...) {
 | 
			
		||||
            LOG(ERROR)
 | 
			
		||||
                << logPrefix()
 | 
			
		||||
                << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
 | 
			
		||||
                << " Please file an issue.";
 | 
			
		||||
          }
 | 
			
		||||
@ -1455,7 +1606,7 @@ void ProcessGroupNCCL::watchdogHandler() {
 | 
			
		||||
 | 
			
		||||
      // Clean up completed work
 | 
			
		||||
      if (work.isCompleted()) {
 | 
			
		||||
        NCCLTraceBuffer::get()->retire_id(work.trace_id_);
 | 
			
		||||
        NCCLTraceBuffer::get()->retire_id(work.trace_id_, true);
 | 
			
		||||
        if (onCompletionHook_) {
 | 
			
		||||
          // Move Work object to completedWorkList_ to be consumed by the hook
 | 
			
		||||
          // thread
 | 
			
		||||
@ -1475,9 +1626,14 @@ void ProcessGroupNCCL::watchdogHandler() {
 | 
			
		||||
        // completed.
 | 
			
		||||
        ++it;
 | 
			
		||||
      }
 | 
			
		||||
      // Increment heartbeat after each work processed,
 | 
			
		||||
      // in case processing is slowed down (but not hung) by cuda api contention
 | 
			
		||||
      heartbeat_++;
 | 
			
		||||
    }
 | 
			
		||||
    // process a request to dump the trace
 | 
			
		||||
    if (dumpPipe.shouldDump()) {
 | 
			
		||||
    // process a request to dump the trace. only PG uid 0 will respond to dump
 | 
			
		||||
    // requests, but this is fine since all PG's feed into the same flight
 | 
			
		||||
    // recorder and dump.
 | 
			
		||||
    if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
 | 
			
		||||
      launchAsyncDebugDump();
 | 
			
		||||
    }
 | 
			
		||||
    done = workMetaList_.empty();
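The watchdog changes above move the timeout flag from the per-group prefixed store to globalStore_, so the TIMEOUT_DUMP key is visible to every rank regardless of which ProcessGroup detected the hang: the timing-out rank sets the key, and the other ranks' watchdogs poll for it and start their own dumps. A hedged sketch of that handshake, written against the abstract c10d::Store interface (the helper names are illustrative; only the key name and the set/check calls come from the diff):

// Cross-rank "please dump debug info" signal, as used by watchdogHandler().
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <vector>

void signalTimeoutDump(const c10::intrusive_ptr<c10d::Store>& globalStore) {
  // The rank that detected a watchdog timeout publishes a flag on the
  // prefix-free (global) store so every other rank sees the same key.
  std::vector<uint8_t> flag(1);
  globalStore->set("timeout_dump", flag);
}

bool othersRequestedDump(const c10::intrusive_ptr<c10d::Store>& globalStore) {
  // The remaining ranks poll this between heartbeats and launch their own
  // asynchronous debug dump once it returns true.
  return globalStore->check({"timeout_dump"});
}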
@@ -1523,8 +1679,8 @@ void ProcessGroupNCCL::runHookLoop() {
      if (std::string(e.what()).find("driver shutting down") !=
          std::string::npos) {
        LOG(INFO)
            << "[Rank " << rank_
            << "] main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
            << logPrefix()
            << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop."
            << " (runHookLoop caught exception: " << e.what();

      } else {
@@ -1698,7 +1854,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
  if (bound_device_id_) {
    for (const auto& device : devices) {
      if (*bound_device_id_ != device) {
        LOG(ERROR) << "Tensor found on device " << device
        LOG(ERROR) << logPrefix() << "Tensor found on device " << device
                   << " but backend constrained to " << *bound_device_id_;
        C10_THROW_ERROR(
            DistBackendError,
@@ -1850,9 +2006,8 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(

  // At this point NCCL should have been initialized, hence we can accurately
  // get the env value even if NCCL sets it by reading from nccl.conf file
  if (getRank() == 0) {
    LOG(INFO) << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");
  }
  LOG(INFO) << logPrefix()
            << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A");

  // See [Group Start/End Note]
  for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
@@ -2148,18 +2303,16 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
}

float ProcessGroupNCCL::WorkNCCL::getDuration() const {
  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
  TORCH_CHECK(
      ncclStartEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
      ncclStartEvents_,
      "getDuration only works if ncclStartEvents_ is populated, true if timing enabled");
  TORCH_CHECK(
      ncclEndEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK(
      (*ncclEndEvents_)[0].query(),
      "getDuration can only be called after work is succeeded.")
  return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
      ncclEndEvents_,
      "getDuration only works if ncclEndEvents_ is populated, which should always be true");
  return getDurationFromFirstEvent(*ncclStartEvents_, *ncclEndEvents_);
}

uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
  return seq_;
}
@@ -3420,14 +3573,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
    // ensure that each process is on a different GPU
    auto numGPUs = at::cuda::getNumGPUs();
    int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs);
    LOG(INFO) << c10::str(
        "Rank ",
        this->getRank(),
        " using GPU ",
        deviceIdx,
        " to perform barrier as devices used by this process are currently unknown. ",
        "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
        "Specify device_ids in barrier() to force use of a particular device.");
    LOG(INFO)
        << logPrefix()
        << c10::str(
               " using GPU ",
               deviceIdx,
               " to perform barrier as devices used by this process are currently unknown. ",
               "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
               "Specify device_ids in barrier() to force use of a particular device.");
    devices.emplace_back(guessDeviceForRank());
  } else {
    for (auto usedDeviceIdx : usedDeviceIdxs_) {

@@ -2,6 +2,7 @@

#ifdef USE_C10D_NCCL

#include <atomic>
#include <chrono>
#include <future>
#include <iostream>
@@ -12,6 +13,7 @@

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <ATen/DynamicLibrary.h>
@@ -26,49 +28,71 @@
#include <torch/custom_class.h>

namespace c10d {
// Environment variable which controls whether we perform a NCCL healt check
// Control whether we perform a NCCL health check or not
// which ensures communicators are healthy at the beginning of init.
static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = {
    "TORCH_ENABLE_NCCL_HEALTH_CHECK",
    "ENABLE_NCCL_HEALTH_CHECK"};

// Environment variable which controls whether or not wait() is blocking or
// non-blocking.
// Control whether or not wait() is blocking or non-blocking.
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
    "TORCH_NCCL_BLOCKING_WAIT",
    "NCCL_BLOCKING_WAIT"};

// Environment variable which controls whether or not we perform Async Error
// Handling with NCCL.
// Control whether or not we perform Async Error Handling with NCCL.
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
    "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    "NCCL_ASYNC_ERROR_HANDLING"};

// Environment Variable to control whether dumping debug info on watchdog
// Control whether dumping debug info on watchdog
// timeout is enabled. This variable must be set together with
// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
    "TORCH_NCCL_DUMP_ON_TIMEOUT"};

// Environment Variable to control whether Desync Debug is enabled.
// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
// Control whether Desync Debug is enabled. This variable must be set
// together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
    "TORCH_NCCL_DESYNC_DEBUG",
    "NCCL_DESYNC_DEBUG"};

// Enable recording start-events for all ProcessGroupNCCL collectives, and
// compute accurate collective timing per-collective. (Note: end-events are
// recorded by default. Turn on this flag can increase chances of a watchdog
// hang due to performing a CUDA event query which eventually calls
// cudaEventElapsedTime() API.
static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
    "TORCH_NCCL_ENABLE_TIMING",
    "NCCL_ENABLE_TIMING"};

// Enable monitoring thread which aborts the process when the ProcessGroupNCCL
// Watchdog thread gets stuck and no heartbeat is detected after
// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged
// time than necessary tying up cluster resources.
static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
    "TORCH_NCCL_ENABLE_MONITORING"};

// Control the watchdog heartbeat timeout period after which the monitoring
// thread will abort the process.
static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
    "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};

// The maximum number of events we store in the flight recorder's ring buffer.
// (One event could be the start or end of a collective, for example).
static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
    "TORCH_NCCL_TRACE_BUFFER_SIZE"};

// Control how much extra time we will wait for dumping the debugging info
// before we exit and throws timeout exception.
static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
    "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};

// Control the interval inside the watchdog thread to check the coordinated
// signal from other ranks, e.g. to dump the debugging information.
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
    "TORCH_NCCL_COORD_CHECK_MILSEC"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";

constexpr const char* TIMEOUT_DUMP = "timeout_dump";
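Per the comments above, timeout dumping only takes effect when the flight recorder buffer is non-empty and the monitoring thread is enabled. A minimal sketch of wiring this up from the launching process, before any ProcessGroupNCCL is constructed; the variable names come from the declarations above, while the concrete values are illustrative:

// Enable the NCCL flight recorder and dump-on-timeout for this process.
#include <cstdlib>

void enableNcclFlightRecorder() {
  // Keep up to 2000 start/end events in the ring buffer (illustrative size).
  setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "2000", /*overwrite=*/1);
  // Dump the recorder on watchdog timeout; per the comment above this must be
  // paired with monitoring and a positive buffer size.
  setenv("TORCH_NCCL_DUMP_ON_TIMEOUT", "1", 1);
  setenv("TORCH_NCCL_ENABLE_MONITORING", "1", 1);
  // Illustrative heartbeat budget before the monitor treats the watchdog as hung.
  setenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "480", 1);
}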
@@ -205,6 +229,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

    uint64_t getSequencenumber() const override;

    const std::string& logPrefix() const;

    // Helper function that sets an exception_ptr on the WorkNCCL object.
    void setException(std::exception_ptr exception_ptr);

@@ -540,8 +566,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  void enableCollectivesTiming() override;

  // Provide an API for users to define their own ways to store NCCL debug info.
  void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
  // Helper function for iteratively aborting communicators in the provided map
  void abortCommsFromMap(
      std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
          ncclCommsMap,
      c10::optional<std::string> abortReason);

  // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
  // instead of relying on ProcessGroupNCCL destructor.
@@ -694,6 +723,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // Desync debug helper
  void logWorkEnd(WorkNCCL& work);

  // Generates a prefix that is unique to this process group and rank, for
  // disambiguating logs
  std::string createLogPrefix() const;

  // Returns the unique prefix created in createLogPrefix
  const std::string& logPrefix() const;

  // Returns the global rank of the device. This function assumes that users
  // always create a default global process group(PG) which includes all
  // devices. It is called in the constructor of ProcessGroupNCCL, so it always
  // return the rank_ of the the very first PG created, aka, default global PG.
  const int& globalRank() const;

 protected:
  // Function that runs as part of a separate thread aside from watchdog
  // thread because we need to check the heartbeat from watchdog thread
@@ -712,6 +754,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // for dump completion.
  std::future<bool> launchAsyncDebugDump();

  // Helper to wait up to the specified timeout and then abandon the dump.
  // Logs on timeout, and asserts the future's status is as expected.
  void waitForDumpOrTimeout(
      std::future<bool>& fut,
      const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
      size_t timeout_sec = 30);

  // When watchdog timeout, this function will be called and return debug info
  // for users. For now we only get information from retrieveDesyncReport.
  // We are working on enabling more useful debug information for watchdog
@@ -720,9 +769,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  static const int64_t kWatchdogThreadSleepMillis;

  // The store is used to broadcast the NCCL unique ID of rank 0.
  // The store is used to broadcast the NCCL unique ID of rank 0. This store
  // comes with prefix and it is different across ProcessGroup NCCL instances
  // (aka, different ProcessGroups).
  c10::intrusive_ptr<Store> store_;

  // Reference to the store without prefix so that keys are same across all
  // ProcessGroup NCCL instances and (key, value) pairs written to the store are
  // global.
  c10::intrusive_ptr<Store> globalStore_;

  bool storeError_{false};

  const c10::intrusive_ptr<Options> options_;
@@ -781,11 +837,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  std::mutex mutex_;

  // Heartbeat of watchdog thread.
  uint64_t heartbeat_;
  std::atomic_uint64_t heartbeat_;

  // The time interval used for deciding whether there is no watchdog heartbeat.
  int heartbeatTimeoutInSec_;

  // Extra time of sleep when waiting for timeout dump to finish.
  int waitTimeoutDumpInMilSec_;

  // Interval of check coordinated signals in ProcessGroupNCCL from other ranks
  // e.g., trigger the dump of the debugging info for timeout when notified.
  int coordCheckIntervalMilSec_;

  // Size of ring buffer where we store NCCL Traces for debugging.
  int ncclTraceBufferSize_;

@@ -823,9 +886,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  bool writeDebugInfo_ = false;

  // Mutex to Guard the check of writeDebugInfo_
  std::mutex writeDebugInfoMutex_;

  // Condition Variable for watchdog thread sleep
  std::condition_variable workMetaListCV_;

@@ -929,14 +989,25 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  std::exception_ptr watchDogException_ = nullptr;

  // The callback function to store NCCL debug info.
  std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;

  size_t uid_;

  std::string logPrefix_;
};

TORCH_API std::string dump_nccl_trace();

// Gets a mutable reference to a global optional function.  Heartbeat Monitor
// will query this function and if available, call it to dump traces. Inside
// fbcode, we store a function here that uses an internal tool for process
// tracing
TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper();

// Similar to get_cpp_trace_dumper, this stores a function defined in
// torch-python layer that lets us check whether the GIL can be acquired,
// helpful for instrumenting in cases where a hang was observed.
typedef bool (*gil_checker_t)();

TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d

#endif // USE_C10D_NCCL
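get_cpp_trace_dumper() above exposes a mutable optional hook that the heartbeat monitor calls, if set, to collect a C++ process trace. A hedged sketch of registering such a hook is shown below; the lambda body and function name are placeholders, and the include path assumes the declarations above live in ProcessGroupNCCL.hpp.

// Register a process-wide C++ trace dumper for the heartbeat monitor to call.
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <string>

void installCppTraceDumper() {
  c10d::get_cpp_trace_dumper() = []() -> std::string {
    // A real hook would invoke whatever stack-collection tooling is available
    // and return its textual output; the monitor treats it as an opaque string.
    return "no C++ stack collection tool wired in (placeholder)";
  };
}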

@@ -269,10 +269,20 @@ inline std::string retrieveDesyncReport(

#ifdef USE_C10D_NCCL

DebugInfoWriter::DebugInfoWriter(int rank) {
  std::string fileName = getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  filename_ = c10::str(fileName, rank);
/* Helper used by work::getDuration() and nccl flight recorder */
float getDurationFromFirstEvent(
    const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
    const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
  TORCH_CHECK(
      ncclStartEvents.size() == 1,
      "getDuration only works for single device per ProcessGroup, but found multiple start events.");
  TORCH_CHECK(
      ncclEndEvents.size() == 1,
      "getDuration only works for single device per ProcessGroup, but found multiple end events.");
  TORCH_CHECK(
      ncclEndEvents[0].query(),
      "getDuration can only be called after work is succeeded.")
  return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
}
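The helper above boils down to one CUDAEvent::elapsed_time call on a start/end pair that has already been recorded and completed. A minimal sketch of that measurement pattern, under the assumption that timing-capable events (cudaEventDefault rather than the timing-disabled default) are used, as ProcessGroupNCCL does when TORCH_NCCL_ENABLE_TIMING is on; all names below are illustrative:

// Time a region of GPU work in milliseconds with a start/end CUDA event pair.
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

float timeRegionMs() {
  // cudaEventDefault keeps timing enabled so elapsed_time() is valid.
  at::cuda::CUDAEvent start(cudaEventDefault);
  at::cuda::CUDAEvent end(cudaEventDefault);
  auto stream = c10::cuda::getCurrentCUDAStream();
  start.record(stream);
  // ... enqueue the work to be timed on `stream` here ...
  end.record(stream);
  end.synchronize();               // ensure the end event completed before querying
  return start.elapsed_time(end);  // milliseconds, like cudaEventElapsedTime()
}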

DebugInfoWriter::~DebugInfoWriter() = default;
@@ -293,6 +303,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) {
  LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
}

DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
  if (writer_ == nullptr) {
    std::string fileNamePrefix = getCvarString(
        {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
    // Using std::unique_ptr here to auto-delete the writer object
    // when the pointer itself is destroyed.
    std::unique_ptr<DebugInfoWriter> writerPtr(
        new DebugInfoWriter(fileNamePrefix, rank));
    DebugInfoWriter::registerWriter(std::move(writerPtr));
  }
  return *writer_;
}

void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
  TORCH_CHECK_WITH(
      DistBackendError,
      hasWriterRegistered_.load() == false,
      "debugInfoWriter already registered");
  hasWriterRegistered_.store(true);
  writer_ = std::move(writer);
}

std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
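With the writer now held as a process-wide singleton, a custom sink is plugged in by subclassing DebugInfoWriter and calling registerWriter() once, before the first dump (a second registration throws, per the check above). A hedged sketch follows; the destination (stderr via LOG) and the header path are assumptions for illustration, and the base-constructor signature mirrors the (namePrefix, rank) call shown in getWriter():

// Custom debug-info sink registered ahead of the first flight-recorder dump.
#include <torch/csrc/distributed/c10d/TraceUtils.h>  // assumed location of DebugInfoWriter at this revision
#include <memory>
#include <string>

class StderrDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  StderrDebugInfoWriter(std::string namePrefix, int rank)
      : c10d::DebugInfoWriter(std::move(namePrefix), rank) {}

  void write(const std::string& ncclTrace) override {
    // Report the size of the pickled flight-recorder blob instead of writing a
    // file; a real writer might forward the bytes to object storage or a logger.
    LOG(ERROR) << "flight recorder dump of " << ncclTrace.size() << " bytes";
  }
};

void installCustomWriter(int rank) {
  c10d::DebugInfoWriter::registerWriter(
      std::make_unique<StderrDebugInfoWriter>("/tmp/custom_nccl_trace_", rank));
}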

inline std::string pickle_str(const c10::IValue& v) {
  std::vector<char> result;
  {
@@ -340,7 +375,7 @@ struct NCCLTraceBuffer {
    const char* profiling_name_;

    std::shared_ptr<torch::CapturedTraceback> traceback_;
    // we borrow pointser to start_ and end_ so we can query the state
    // we borrow pointers to start_ and end_ so we can query the state
    // on reporting. However, once the event is completed, the call
    // to `complete` will clear these.
    EventList *start_, *end_;
@@ -348,6 +383,7 @@ struct NCCLTraceBuffer {
    // timestamp when the entry was created, likely close to the time the work
    // was 'enqueued'- not necessarily started
    c10::time_t time_created_;
    c10::optional<float> duration_;

    const char* state_ = "scheduled";

@@ -386,7 +422,7 @@ struct NCCLTraceBuffer {
        id_,
        pg_id,
        seq_id,
        profiling_name,
        profiling_name == nullptr ? "" : profiling_name,
        std::move(traceback),
        std::move(start),
        std::move(end),
@@ -456,35 +492,90 @@ struct NCCLTraceBuffer {
    return result;
  }

  void retire_id(c10::optional<size_t> id) {
  /*
  Mark an Event as completed and free its events.

  This is called by the watchdog thread, and is asynchronous from the
  perspective of the main thread.

  compute_duration defaults to true since retire_id is only called in the
  watchdog thread, which is currently a place we call cuda APIs which may hang,
  but care should be taken to avoid computing duration in any function that must
  never hang. (timing must also be enabled for compute_duration - see
  TORCH_NCCL_ENABLE_TIMING).
  */
  void retire_id(c10::optional<size_t> id, bool compute_duration = true) {
    if (!enabled_ || !id) {
      return;
    }
    std::lock_guard<std::mutex> guard(mutex_);

    bool can_compute_duration = false;
    EventList* startEvents = nullptr;
    EventList* endEvents = nullptr;
    c10::optional<float> duration = c10::nullopt;

    std::unique_lock<std::mutex> guard(mutex_);

    auto& entry = entries_.at(*id % max_entries_);
    if (entry.id_ == *id) {
      update_state(entry);
      entry.retired_ = true;
      entry.start_ = entry.end_ = nullptr;

      if (compute_duration) {
        can_compute_duration = strcmp(entry.state_, "completed") == 0 &&
            entry.start_ && entry.end_;
        startEvents = entry.start_;
        endEvents = entry.end_;
      }
    }

    if (can_compute_duration) {
      // Compute duration without without holding the lock, because
      // cudaEventDuration() can hang, and we need to acquire the lock before we
      // can dump(), which we never want to block.
      guard.unlock();
      duration = getDurationFromFirstEvent(*startEvents, *endEvents);
      guard.lock();

      // Refresh the entry ref, see if it has been overwritten
      entry = entries_.at(*id % max_entries_);
      if (entry.id_ != *id) {
        LOG(INFO)
            << "retire_id abandoned for id " << *id
            << ", event was overwritten while waiting to compute duration.";
        return;
      }
      if (duration.has_value()) {
        entry.duration_ = duration.value();
      }
    }

    entry.retired_ = true;
    entry.start_ = entry.end_ = nullptr;
  }
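The rewritten retire_id() follows a classic concurrency pattern: it releases the ring-buffer mutex around the CUDA call that may hang, then re-acquires the lock and re-validates that the slot was not overwritten before publishing the duration. A stripped-down, hedged sketch of just that pattern (all names here are illustrative and unrelated to the flight recorder types):

// Generic "unlock around a blocking call, then revalidate" pattern.
#include <mutex>

struct SlotCache {
  std::mutex mu;
  int generation = 0;
  double value = 0.0;

  void refresh(int expected_generation, double (*slow_compute)()) {
    std::unique_lock<std::mutex> lock(mu);
    if (generation != expected_generation) {
      return;  // the slot was already recycled; nothing to do
    }
    lock.unlock();
    double v = slow_compute();  // potentially blocking; never hold mu here
    lock.lock();
    if (generation == expected_generation) {
      value = v;  // slot is still ours, safe to publish the result
    }
  }
};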

  std::string dump() {
    auto result = dump_entries();
    auto entries = new_list();
    c10::IValue pg_id_s = "pg_id";
    c10::IValue seq_id_s = "seq_id";
    c10::IValue profiling_name_s = "profiling_name";
    c10::IValue input_sizes_s = "input_sizes";
    c10::IValue output_sizes_s = "output_sizes";
    c10::IValue time_created_s = "time_created_us";
    c10::IValue entries_key = "entries";
    c10::IValue version_key = "version";
    // Update whenever changing contents or formatting of the dump
    // (minor when adding fields, major when changing existing fields)
    c10::IValue version_val = "1.0";

    c10::IValue frames_s = "frames";
    c10::IValue state_s = "state";
    c10::IValue line_s = "line";
    c10::IValue name_s = "name";
    c10::IValue filename_s = "filename";
    c10::IValue retired_s = "retired";
    c10::IValue pg_id_key = "pg_id";
    c10::IValue seq_id_key = "seq_id";
    c10::IValue profiling_name_key = "profiling_name";
    c10::IValue input_sizes_key = "input_sizes";
    c10::IValue output_sizes_key = "output_sizes";
    c10::IValue time_created_key = "time_created_ns";
    c10::IValue duration_key = "duration_ms";

    c10::IValue frames_key = "frames";
    c10::IValue state_key = "state";
    c10::IValue line_key = "line";
    c10::IValue name_key = "name";
    c10::IValue filename_key = "filename";
    c10::IValue retired_key = "retired";

    std::vector<torch::CapturedTraceback*> tracebacks;
    for (auto& e : result) {
@@ -494,9 +585,9 @@ struct NCCLTraceBuffer {
    std::vector<c10::IValue> all_frames;
    for (const auto& f : stracebacks.all_frames) {
      auto d = new_dict();
      d.insert(name_s, f.funcname);
      d.insert(filename_s, f.filename);
      d.insert(line_s, int64_t(f.lineno));
      d.insert(name_key, f.funcname);
      d.insert(filename_key, f.filename);
      d.insert(line_key, int64_t(f.lineno));
      all_frames.emplace_back(std::move(d));
    }

@@ -504,10 +595,13 @@ struct NCCLTraceBuffer {
      auto& e = result.at(i);
      auto& tb = stracebacks.tracebacks.at(i);
      auto dict = new_dict();
      dict.insert(pg_id_s, int64_t(e.pg_id_));
      dict.insert(seq_id_s, int64_t(e.seq_id_));
      dict.insert(profiling_name_s, e.profiling_name_);
      dict.insert(time_created_s, int64_t(e.time_created_ / 1000));
      dict.insert(pg_id_key, int64_t(e.pg_id_));
      dict.insert(seq_id_key, int64_t(e.seq_id_));
      dict.insert(profiling_name_key, e.profiling_name_);
      dict.insert(time_created_key, int64_t(e.time_created_));
      if (e.duration_) {
        dict.insert(duration_key, *e.duration_);
      }

      auto it = e.sizes_.begin();
      auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
@@ -523,19 +617,24 @@ struct NCCLTraceBuffer {
        return sizes;
      };

      dict.insert(input_sizes_s, read_sizes(e.input_dims_));
      dict.insert(output_sizes_s, read_sizes(e.output_dims_));
      dict.insert(state_s, e.state_);
      dict.insert(retired_s, e.retired_);
      dict.insert(input_sizes_key, read_sizes(e.input_dims_));
      dict.insert(output_sizes_key, read_sizes(e.output_dims_));
      dict.insert(state_key, e.state_);
      dict.insert(retired_key, e.retired_);

      auto frames = new_list();
      for (int64_t frame : tb) {
        frames.push_back(all_frames.at(frame));
      }
      dict.insert(frames_s, frames);
      dict.insert(frames_key, frames);
      entries.push_back(dict);
    }
    return pickle_str(entries);

    auto dict = new_dict();
    dict.insert(entries_key, entries);
    dict.insert(version_key, version_val);

    return pickle_str(dict);
  }
};
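dump() now pickles a top-level dict carrying a "version" string ("1.0") and the "entries" list, instead of a bare list of entries. A hedged sketch of persisting one snapshot from C++ through the dump_nccl_trace() entry point declared in the header diff above; the file path is illustrative, and the include path assumes that declaration lives in ProcessGroupNCCL.hpp:

// Write the current flight-recorder snapshot to disk as raw pickle bytes.
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <fstream>
#include <string>

void saveNcclTrace(const std::string& path) {
  std::string blob = c10d::dump_nccl_trace();
  std::ofstream out(path, std::ios::binary);
  // The file can later be unpickled; expect keys "version" and "entries".
  out.write(blob.data(), static_cast<std::streamsize>(blob.size()));
}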


@@ -50,6 +50,35 @@

namespace {

#ifdef USE_C10D_NCCL

bool acquire_gil() {
  // basically if this function can acquire the gil, it will return quickly.
  // if not, it will hang forever.  The idea is to call this from a thread
  // wrapped in a future, and then check the future after a timeout, to
  // determine whether we're facing gil contention.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    return true;
  }

  // If we end up here, its probably still a "pass" from the perspective of
  // checking whether python is stuck. but currently we don't check the return
  // value of this function anyway, just check whether it returned quickly vs
  // timing out.  Taking a long time is the main sign of trouble.  Fast return
  // with true or with false is both OK from the perspective of debugging python
  // hangs.
  return false;
}

bool registerGilChecker() {
  c10d::get_gil_checker() = &acquire_gil;
  return true;
}

static bool registered = registerGilChecker();
#endif // USE_C10D_NCCL
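As the acquire_gil() comments describe, the registered checker is meant to be run on a helper thread and judged by how long it takes rather than by its return value. A hedged sketch of a consumer, with an illustrative 5-second budget and a function name that is not part of the diff; the include path assumes get_gil_checker() is declared in ProcessGroupNCCL.hpp:

// Treat a slow GIL acquisition as evidence of a stuck Python thread.
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <chrono>
#include <future>

bool gilLooksContended() {
  auto checker = c10d::get_gil_checker();  // gil_checker_t, a plain function pointer
  if (checker == nullptr) {
    return false;  // torch-python never registered a checker; nothing to measure
  }
  auto fut = std::async(std::launch::async, [checker] { return (*checker)(); });
  // If acquiring the GIL takes longer than the budget, report probable contention.
  return fut.wait_for(std::chrono::seconds(5)) == std::future_status::timeout;
}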

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
@@ -1033,6 +1062,29 @@ Example::
    >>> store.add("first_key", 6)
    >>> # Should return 7
    >>> store.get("first_key")
)")
          .def(
              "check",
              &::c10d::Store::check,
              py::call_guard<py::gil_scoped_release>(),
              R"(
The call to check whether a given list of ``keys`` have value stored in
the store. This call immediately returns in normal cases but still suffers
from some edge deadlock cases, e.g, calling check after TCPStore has been destroyed.
Calling :meth:`~torch.distributed.store.check` with a list of keys that
one wants to check whether stored in the store or not.

Arguments:
    keys (lisr[str]): The keys to query whether stored in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> # Should return 7
    >>> store.check(["first_key"])
)")
          .def(
              "delete_key",
@@ -1404,7 +1456,11 @@ Arguments:
      .def_property_readonly(
          "underlying_store",
          &::c10d::PrefixStore::getUnderlyingStore,
          R"(Gets the underlying store object that PrefixStore wraps around.)");
          R"(Gets the underlying store object that PrefixStore wraps around.)")
      .def_property_readonly(
          "_underlying_non_prefix_store",
          &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
          R"(Recursively to get the store before layers of wrapping with PrefixStore.)");

  auto processGroup =
      py::class_<