Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 04:04:57 +08:00)

Compare commits: aoti-cuda-...sqzhang_fl (48 commits)
| SHA1 |
|---|
| dde4324d8e |
| 94c079104d |
| a6afee6d94 |
| d092857531 |
| 6aad5e444a |
| c54ce9313b |
| 1fe59f4ef7 |
| e693fb2bb1 |
| 4fe510baf6 |
| 7c507b78c4 |
| 0019901601 |
| 18be18535b |
| 2729367313 |
| 33537aae24 |
| dcdb1337dd |
| 9cf0f2bd59 |
| 1d2e877c05 |
| f27b979b0c |
| f30d6047ad |
| 75311510ef |
| dbd6094d05 |
| 397b9d47e9 |
| 36a01a8ab9 |
| ee336cf58a |
| a9e2e745d7 |
| ab4df89eea |
| 9d02ebe876 |
| b61e01cce9 |
| f7ce61ba53 |
| e7bae15ab1 |
| e71b422908 |
| 389940ce60 |
| b2237a7c85 |
| ef5dfe3f3e |
| e303dc3c08 |
| 265efad2de |
| 60f0455905 |
| 4898313791 |
| f4da9adf6b |
| 8f7f35273e |
| 44ec9612ed |
| 4d3bea2b29 |
| 0bcdddc3c1 |
| 28b6220312 |
| 210b7b65e2 |
| 4da10b5cd3 |
| f09763814f |
| 80923ed5a6 |

| @ -1732,7 +1732,7 @@ if(BUILD_TEST) | ||||
|   foreach(test_src ${Caffe2_CPU_TEST_SRCS}) | ||||
|     get_filename_component(test_name ${test_src} NAME_WE) | ||||
|     add_executable(${test_name} "${test_src}") | ||||
|     target_link_libraries(${test_name} torch_library gtest_main) | ||||
|     target_link_libraries(${test_name} torch_library gtest_main stdc++) | ||||
|     target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>) | ||||
|     target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>) | ||||
|     target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) | ||||
|  | ||||
| @ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) { | ||||
| // Extend the nested class outside the parent class | ||||
| class TestDebugInfoWriter : public c10d::DebugInfoWriter { | ||||
|  public: | ||||
|   TestDebugInfoWriter() : DebugInfoWriter(0) {} | ||||
|   TestDebugInfoWriter(std::string namePrefix) | ||||
|       : DebugInfoWriter(namePrefix, 0) {} | ||||
|  | ||||
|   void write(const std::string& ncclTrace) override { | ||||
|     traces_.assign(ncclTrace.begin(), ncclTrace.end()); | ||||
| @ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { | ||||
|   // The storer here is very similar to the fallback storer. | ||||
|   // The only difference is that we are storing traces also in memory for | ||||
|   // validation. | ||||
|   std::string fileNamePrefix = c10d::getCvarString( | ||||
|       {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|   std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr = | ||||
|       std::make_unique<TestDebugInfoWriter>(); | ||||
|       std::make_unique<TestDebugInfoWriter>(fileNamePrefix); | ||||
|   std::vector<uint8_t>& traces = wrterForTestPtr->getTraces(); | ||||
|   pg.registerDebugInfoWriter(std::move(wrterForTestPtr)); | ||||
|   c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr)); | ||||
|  | ||||
|   // Normal collective case. | ||||
|   auto work = pg.allreduce(tensors_); | ||||
| @ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { | ||||
| class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest { | ||||
|  protected: | ||||
|   void SetUp() override { | ||||
|     // TODO (kwen2501) | ||||
|     GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; " | ||||
|                  << "will rewrite them after refactoring Work queues."; | ||||
|     ProcessGroupNCCLErrorsTest::SetUp(); | ||||
|     std::string timeInterval = std::to_string(heartBeatIntervalInSec); | ||||
|     ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); | ||||
|  | ||||
| @ -3542,11 +3542,12 @@ class SparseCollective(MultiProcessTestCase): | ||||
| class NCCLTraceTestBase(MultiProcessTestCase): | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
|         os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0' | ||||
|         os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'  # see 'timing_enabled' parametrized tests | ||||
|         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10' | ||||
|         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1' | ||||
|         self.tempdir = tempfile.TemporaryDirectory() | ||||
|         os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename() | ||||
|         os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename() | ||||
|         self._spawn_processes() | ||||
|  | ||||
|     @classmethod | ||||
| @ -3617,28 +3618,49 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     def test_short(self): | ||||
|     @parametrize("timing_enabled", [True, False]) | ||||
|     def test_short(self, timing_enabled): | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             return | ||||
|         pg = self._create_process_group_nccl() | ||||
|         if timing_enabled: | ||||
|             pg._enable_collectives_timing() | ||||
|         device = self.local_device | ||||
|         a = torch.full((3, 4), float(self.rank), device=device) | ||||
|         for i in range(2): | ||||
|             f = pg.allreduce(a) | ||||
|         f.wait() | ||||
|         torch.cuda.synchronize(device=device) | ||||
|  | ||||
|         # duration_ms is populated best-effort, since it can only be filled in outside the "dump()" api | ||||
|         time.sleep(1) | ||||
|  | ||||
|         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|         ver = t['version'] | ||||
|         self.assertEqual(ver, "1.1") | ||||
|         t = t['entries'] | ||||
|         self.assertEqual(len(t), 2) | ||||
|         last = t[-1] | ||||
|         self.assertEqual(last['state'], 'completed') | ||||
|         s = last['time_discovered_started_ns'] | ||||
|         f = last['time_discovered_completed_ns'] | ||||
|         self.assertIsNotNone(f) | ||||
|         if timing_enabled: | ||||
|             self.assertIsNotNone(s) | ||||
|             self.assertTrue(s <= f) | ||||
|         self.assertIn('test_c10d_nccl.py', str(last['frames'])) | ||||
|         self.assertEqual(last['input_sizes'], ((3, 4),)) | ||||
|         self.assertEqual(last['output_sizes'], ((3, 4),)) | ||||
|         self.assertEqual(last['seq_id'], 2) | ||||
|         now = datetime.now() | ||||
|         event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000) | ||||
|         event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000) | ||||
|         before_test = now - timedelta(minutes=1) | ||||
|         self.assertTrue(before_test < event_created_time < now) | ||||
|         if timing_enabled: | ||||
|             # very loose bounds, measured 0.036 ms on devgpu | ||||
|             self.assertTrue(0 < last['duration_ms'] < 100) | ||||
|         else: | ||||
|             self.assertTrue("duration_ms" not in last) | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
| @ -3698,9 +3720,11 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|         f.wait() | ||||
|         torch.cuda.synchronize(device=device) | ||||
|         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|         t = t['entries'] | ||||
|         self.assertEqual(len(t), 10) | ||||
|         first = t[0] | ||||
|         last = t[-1] | ||||
|         self.assertEqual(last['profiling_name'], 'nccl:all_reduce') | ||||
|         self.assertEqual(last['state'], 'completed') | ||||
|         self.assertIn('test_c10d_nccl.py', str(last['frames'])) | ||||
|         self.assertEqual(last['input_sizes'], ((3, 4),)) | ||||
| @ -3732,6 +3756,8 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|                 pg.allreduce(a).wait() | ||||
|             e.synchronize() | ||||
|             t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|             t = t['entries'] | ||||
|             self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce') | ||||
|             if self.rank == 0: | ||||
|                 self.assertEqual(t[-1]['seq_id'], 1) | ||||
|                 self.assertEqual(t[-1]['state'], 'completed') | ||||
| @ -3773,12 +3799,15 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|                 # give the other thread some time to fill the cuda buffer | ||||
|                 time.sleep(5) | ||||
|                 t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce') | ||||
|                 if self.rank == 0: | ||||
|                     self.assertEqual(t[-1]['seq_id'], 1) | ||||
|                     self.assertEqual(t[-1]['state'], 'completed') | ||||
|                 else: | ||||
|                     self.assertEqual(t[-1]['seq_id'], 2) | ||||
|                     self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled)) | ||||
|                     self.assertIsNone(t[-1]['time_discovered_completed_ns']) | ||||
|                 # this will eventually cause the missing rank 0 | ||||
|                 # to continue which will unblock the non-zero ranks | ||||
|                 self.parent.send('next') | ||||
| @ -3832,6 +3861,9 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     @parametrize("timing_enabled", [True, False]) | ||||
|     def test_timeout_dumps(self, timing_enabled): | ||||
|         # Effectively disable the coordinated timeout dump so that rank 0 does not | ||||
|         # also time out; set the check frequency to a very large value (25 min). | ||||
|         os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1500000' | ||||
|  | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             # wait for rank0 to crash before looking for its output file | ||||
| @ -3839,6 +3871,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             self.assertEqual(self._wait_process(0, timeout=90), -6) | ||||
|             with open(self._trace_name(rank=0), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 2) | ||||
|                 self.assertEqual(t[0]['seq_id'], 1) | ||||
|                 self.assertEqual(t[0]['state'], 'completed') | ||||
| @ -3868,7 +3901,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
| instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout) | ||||
| instantiate_parametrized_tests(NCCLTraceTest) | ||||
|  | ||||
| class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
| class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|     def _check_return_codes(self, elapsed_time): | ||||
|         # the base test infra assumes processes exit with matching return codes, | ||||
|         # but we want rank0 to abort and rank1 to exit cleanly in this test | ||||
| @ -3877,7 +3910,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     def test_timeout_dumps_on_idle_ranks(self): | ||||
|     def test_timeout_dumps_on_stuck_ranks(self): | ||||
|  | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             # wait for both rank0 and 1 to crash before looking for both ranks' output | ||||
| @ -3888,20 +3921,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             self.assertTrue(os.path.exists(self._trace_name(rank=0))) | ||||
|             with open(self._trace_name(rank=0), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 2) | ||||
|             with open(self._trace_name(rank=1), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 1) | ||||
|                 self.assertEqual(t[0]['seq_id'], 1) | ||||
|                 self.assertEqual(t[0]['state'], 'completed') | ||||
|             return | ||||
|  | ||||
|         # Set heartbeat timeout to a shorter one (default timeout is 2 min). | ||||
|         os.environ[ | ||||
|             "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC" | ||||
|         ] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}" | ||||
|         pg = self._create_process_group_nccl() | ||||
|  | ||||
|         device = self.local_device | ||||
|         with torch.cuda.device(device): | ||||
|             a = torch.full((3, 4), float(self.rank), device=device) | ||||
| @ -3910,12 +3940,14 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             if self.rank == 0: | ||||
|                 pg.allreduce(a).wait() | ||||
|  | ||||
|             # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly | ||||
|             # rank 0 will get stuck, timeout and then signal a timeout to all ranks. | ||||
|             torch.cuda.synchronize() | ||||
|  | ||||
|             # Force rank 1 to idle so that it also gets debug info dump triggered. | ||||
|             if self.rank == 1: | ||||
|                 time.sleep(6) | ||||
|                 # Force rank 1 to idle so that it will eventually timeout as well after | ||||
|                 # getting the global signal to dump the debugging info. | ||||
|                 time.sleep(600) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     assert ( | ||||
|  | ||||
| @ -71,7 +71,7 @@ class StoreTestBase: | ||||
|     def _create_store(self, i): | ||||
|         raise RuntimeError("not implemented") | ||||
|  | ||||
|     def _test_set_get(self, fs): | ||||
|     def _test_set_get_check(self, fs): | ||||
|         fs.add("key", 1) | ||||
|         fs.add("key", 2) | ||||
|         fs.add("key", 3) | ||||
| @ -90,14 +90,16 @@ class StoreTestBase: | ||||
|         self.assertEqual(b"value1", fs.get("key1")) | ||||
|         self.assertEqual(b"value2", fs.get("key2")) | ||||
|         self.assertEqual(b"21", fs.get("key3")) | ||||
|         self.assertTrue(fs.check(["key3"])) | ||||
|         self.assertFalse(fs.check(["Randomkey3"])) | ||||
|  | ||||
|         fs.set("-key3", "7") | ||||
|         self.assertEqual(b"7", fs.get("-key3")) | ||||
|         fs.delete_key("-key3") | ||||
|         self.assertEqual(fs.num_keys(), self.num_keys_total) | ||||
|  | ||||
|     def test_set_get(self): | ||||
|         self._test_set_get(self._create_store()) | ||||
|     def test_set_get_check(self): | ||||
|         self._test_set_get_check(self._create_store()) | ||||
|  | ||||
|     def _test_compare_set(self, store): | ||||
|         missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0") | ||||
| @ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase): | ||||
|     def num_keys_total(self): | ||||
|         return 6 | ||||
|  | ||||
|     def test_underlying_non_prefix_store(self): | ||||
|         store = self._create_store() | ||||
|         wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store)) | ||||
|         self.assertEqual(self.tcpstore, store._underlying_non_prefix_store) | ||||
|         self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store) | ||||
|  | ||||
| class MyPythonStore(dist.Store): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
|  | ||||
| #include <ATen/ATen.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/Optional.h> | ||||
| #include <nccl.h> | ||||
| @ -175,14 +176,29 @@ std::string getNcclErrorDetailStr( | ||||
|     c10::optional<std::string> processGroupFailureReason = c10::nullopt); | ||||
|  | ||||
| // Write NCCL debug info to local disk or any storage users define. | ||||
| // There are some constraints we set for the debug info writer: | ||||
| // 1. The writer should only be registered once. | ||||
| // 2. Once registered, users cannot change it, including un-registering it. | ||||
| // 3. It is recommended to register the customized writer during trainer setup. | ||||
| //    If users don't register one before launchAsyncDebugDump is called, they | ||||
| //    lose the chance to register (and the default writer will be | ||||
| //    auto-registered). | ||||
| class TORCH_API DebugInfoWriter { | ||||
|  public: | ||||
|   DebugInfoWriter(int rank); | ||||
|   virtual ~DebugInfoWriter(); | ||||
|   virtual void write(const std::string& ncclTrace); | ||||
|   static DebugInfoWriter& getWriter(int rank); | ||||
|   static void registerWriter(std::unique_ptr<DebugInfoWriter> writer); | ||||
|  | ||||
|  protected: | ||||
|   DebugInfoWriter(std::string namePrefix, int rank) { | ||||
|     filename_ = c10::str(namePrefix, rank); | ||||
|   } | ||||
|   std::string filename_; | ||||
|  | ||||
|  private: | ||||
|   static std::unique_ptr<DebugInfoWriter> writer_; | ||||
|   static std::atomic<bool> hasWriterRegistered_; | ||||
| }; | ||||
|  | ||||
| // RAII wrapper for NCCL communicator | ||||
| @ -267,6 +283,18 @@ class NCCLComm { | ||||
|   } | ||||
| #endif | ||||
|  | ||||
| #if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP) | ||||
|   std::unordered_map<std::string, std::string> ncclCommDump() { | ||||
|     std::unordered_map<std::string, std::string> dump; | ||||
|     if (isAborted()) { | ||||
|       LOG(INFO) << "Communicator was aborted before trying to dump its state."; | ||||
|       return dump; | ||||
|     } | ||||
|     C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt); | ||||
|     return dump; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   ncclUniqueId getNcclId() { | ||||
|     return ncclId_; | ||||
|   } | ||||
| @ -322,6 +350,9 @@ class NCCLComm { | ||||
|     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work | ||||
|     // timeout) | ||||
|     commFailureReason_ = commFailureReason; | ||||
|     LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: " | ||||
|               << (commFailureReason ? *commFailureReason | ||||
|                                     : "No abort reason provided."); | ||||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||||
|     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); | ||||
| #else | ||||
| @ -421,6 +452,8 @@ class NCCLComm { | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   friend class ProcessGroupNCCL; | ||||
|  | ||||
|  protected: | ||||
|   ncclComm_t ncclComm_; | ||||
|   // Unique nccl_id for this communicator. | ||||
|  | ||||
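The constraints listed above the `DebugInfoWriter` declaration mean a custom writer must be installed exactly once, during trainer setup, before the first dump can be launched. A minimal sketch of such a writer follows; it only relies on the declarations shown in this hunk, while the include path and the storage sink are assumptions for illustration.

```cpp
// Hypothetical sketch: a writer that forwards the NCCL flight-recorder blob to
// a user-defined sink instead of the default local-disk file. The include path
// is assumed; the "sink" here is just a stand-in printout.
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

#include <iostream>
#include <memory>
#include <string>
#include <utility>

class BlobStoreDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  BlobStoreDebugInfoWriter(std::string namePrefix, int rank)
      : DebugInfoWriter(std::move(namePrefix), rank) {}

  void write(const std::string& ncclTrace) override {
    // filename_ was composed by the protected base constructor as namePrefix + rank.
    // Placeholder sink: a real writer would push the blob to the user's storage.
    std::cerr << "would upload " << ncclTrace.size() << " bytes as " << filename_ << "\n";
  }
};

void setupFlightRecorderWriter(int globalRank) {
  // Register once, during setup, before any dump can be triggered; a second
  // registration would violate constraint (1) above.
  c10d::DebugInfoWriter::registerWriter(
      std::make_unique<BlobStoreDebugInfoWriter>("blobstore://nccl_trace_rank_", globalRank));
}
```

This mirrors the updated `TestDebugInfoWriter` in the C++ test above, which now passes a name prefix to the protected base constructor and registers itself through the static `registerWriter` instead of a per-process-group hook.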
| @ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() { | ||||
|   return store_; | ||||
| } | ||||
|  | ||||
| c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() { | ||||
|   c10::intrusive_ptr<Store> store = store_; | ||||
|  | ||||
|   while (store) { | ||||
|     // Attempt to dynamically cast to PrefixStore | ||||
|     PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get()); | ||||
|     if (asPrefixStore) { | ||||
|       store = asPrefixStore->getUnderlyingStore(); | ||||
|     } else { | ||||
|       break; // We've reached a non-PrefixStore | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       store != nullptr, "Underlying Non-PrefixStore shouldn't be null."); | ||||
|   return store; | ||||
| } | ||||
|  | ||||
| } // namespace c10d | ||||
|  | ||||
| @ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store { | ||||
|  | ||||
|   c10::intrusive_ptr<Store> getUnderlyingStore(); | ||||
|  | ||||
|   // Recursively fetch the store underneath any layers of PrefixStore wrapping. | ||||
|   c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore(); | ||||
|  | ||||
|  protected: | ||||
|   std::string prefix_; | ||||
|   c10::intrusive_ptr<Store> store_; | ||||
|  | ||||
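A short sketch of how this accessor can be used to reach the shared base store, mirroring what the `ProcessGroupNCCL` constructor does later in this diff. The include paths are assumptions; the rest uses only the API shown above.

```cpp
// Sketch: resolve the shared, non-prefixed store behind arbitrarily nested
// PrefixStore wrappers. Include paths are assumed, not verified.
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <c10/util/intrusive_ptr.h>

c10::intrusive_ptr<c10d::Store> resolveGlobalStore(
    const c10::intrusive_ptr<c10d::Store>& store) {
  // If the store handed to a backend is a PrefixStore (possibly wrapping
  // another PrefixStore), unwrap it so that keys written here are visible to
  // every process group sharing the same underlying store.
  auto* prefixStore = dynamic_cast<c10d::PrefixStore*>(store.get());
  return prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store;
}
```

Nested wrapping such as `PrefixStore(PrefixStore(TCPStore))` resolves all the way down to the `TCPStore`, which is exactly what the new `test_underlying_non_prefix_store` Python test above checks via `_underlying_non_prefix_store`.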
| @ -337,8 +337,62 @@ void cacheAllocatorDeregisterHook( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP) | ||||
| std::string dump_nccl_trace() { | ||||
|   return NCCLTraceBuffer::get()->dump(); | ||||
|   std::unordered_map< | ||||
|       std::string /* ncclUniqueID */, | ||||
|       std::unordered_map<std::string, std::string> /* dump from this comm */> | ||||
|       ncclDumpMap; | ||||
|   // dump_nccl_trace is only called from the default PG (uid_=0), but we want to | ||||
|   // dump from all comms so we need to iterate over ncclCommDevIdxMap, which | ||||
|   // is static | ||||
|   std::vector<std::shared_ptr<NCCLComm>> allNCCLComms; | ||||
|   // within the critical section, we don't want to dump while holding the lock | ||||
|   // as dump might hang | ||||
|   ncclCommDevIdxMapMutex.lock(); | ||||
|   for (auto& [ncclComm, _] : ncclCommDevIdxMap) { | ||||
|     allNCCLComms.push_back(ncclComm); | ||||
|   } | ||||
|   ncclCommDevIdxMapMutex.unlock(); | ||||
|   for (auto& ncclComm : allNCCLComms) { | ||||
|     std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); | ||||
|     ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); | ||||
|   } | ||||
|   return NCCLTraceBuffer::get()->dump(ncclDumpMap); | ||||
| } | ||||
| #else | ||||
| std::string dump_nccl_trace() { | ||||
|   return NCCLTraceBuffer::get()->dump(c10::nullopt); | ||||
| } | ||||
| #endif | ||||
|  | ||||
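The comm-dump loop above copies the communicator handles out of `ncclCommDevIdxMap` while holding its mutex and only then calls the potentially slow (or hanging) `ncclCommDump()`. A generic, self-contained sketch of that snapshot-then-operate pattern, with all names illustrative:

```cpp
// Sketch of the "snapshot under lock, operate outside the lock" pattern used
// by dump_nccl_trace() above. All names here are illustrative.
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

struct Comm {
  std::string dump() const { return "state"; }  // stand-in for a slow or hanging call
};

std::mutex gCommsMutex;
std::map<std::shared_ptr<Comm>, int> gComms;  // guarded by gCommsMutex

std::vector<std::string> dumpAllComms() {
  // 1) Hold the lock only long enough to copy out the shared_ptrs.
  std::vector<std::shared_ptr<Comm>> snapshot;
  {
    std::lock_guard<std::mutex> guard(gCommsMutex);
    for (const auto& kv : gComms) {
      snapshot.push_back(kv.first);
    }
  }
  // 2) Run the potentially slow dump calls without holding the lock, so a hung
  //    dump cannot block every other thread that needs the map.
  std::vector<std::string> dumps;
  for (const auto& comm : snapshot) {
    dumps.push_back(comm->dump());
  }
  return dumps;
}
```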
| c10::optional<std::function<std::string()>>& get_cpp_trace_dumper() { | ||||
|   static c10::optional<std::function<std::string()>> dumper(c10::nullopt); | ||||
|   return dumper; | ||||
| } | ||||
|  | ||||
| gil_checker_t& get_gil_checker() { | ||||
|   static gil_checker_t gil_checker = nullptr; | ||||
|   return gil_checker; | ||||
| } | ||||
|  | ||||
| std::future<bool> launchAsyncGilCheck() { | ||||
|   std::promise<bool> resultPromise; | ||||
|   std::future<bool> resultFuture = resultPromise.get_future(); | ||||
|   TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); | ||||
|   std::thread workerThread([promise = std::move(resultPromise)]() mutable { | ||||
|     try { | ||||
|       auto& gil_checker = get_gil_checker(); | ||||
|       promise.set_value((*gil_checker)()); | ||||
|     } catch (...) { | ||||
|       promise.set_exception(std::current_exception()); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // Detach the thread to allow it to run independently | ||||
|   workerThread.detach(); | ||||
|  | ||||
|   return resultFuture; | ||||
| } | ||||
|  | ||||
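`launchAsyncGilCheck()` follows the same promise/future-on-a-detached-thread shape as `launchAsyncDebugDump()`, which lets the heartbeat monitor bound how long it waits for the check (300 ms later in this file). A self-contained sketch of that shape, where the checker body is only a stand-in:

```cpp
// Self-contained sketch of the bounded async-check pattern used by
// launchAsyncGilCheck(): run a callable on a detached thread, hand back a
// future, and let the caller wait with a timeout. The checker is a stand-in.
#include <chrono>
#include <functional>
#include <future>
#include <iostream>
#include <thread>

std::future<bool> launchAsyncCheck(std::function<bool()> check) {
  std::promise<bool> resultPromise;
  std::future<bool> resultFuture = resultPromise.get_future();
  std::thread worker([promise = std::move(resultPromise),
                      check = std::move(check)]() mutable {
    try {
      promise.set_value(check());
    } catch (...) {
      promise.set_exception(std::current_exception());
    }
  });
  worker.detach();  // the caller only interacts with it through the future
  return resultFuture;
}

int main() {
  auto fut = launchAsyncCheck([] { return true; });  // e.g. "can we take the GIL?"
  if (fut.wait_for(std::chrono::milliseconds(300)) == std::future_status::ready) {
    std::cout << "check returned " << fut.get() << "\n";
  } else {
    std::cout << "check did not finish in 300 ms; possible hang\n";
  }
}
```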
| // Return CUDA device with ordinal given by input rank.  If we aren't | ||||
| @ -358,7 +412,7 @@ at::Device ProcessGroupNCCL::guessDeviceForRank() const { | ||||
|   } | ||||
| } | ||||
|  | ||||
| const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000; | ||||
| const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; | ||||
| constexpr int64_t kSynchronizeBusyWaitMillis = 10; | ||||
| thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; | ||||
|  | ||||
| @ -466,12 +520,17 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { | ||||
|   std::unique_lock<std::mutex> lock(mutex_); | ||||
|   exception_ = exception_ptr; | ||||
|   if (exception_) { | ||||
|     LOG(INFO) << "[Rank " << rank_ << "]" | ||||
|               << " found async exception when checking for NCCL errors: " | ||||
|     LOG(INFO) << logPrefix() | ||||
|               << "found async exception when checking for NCCL errors: " | ||||
|               << getExceptionMsgFromExceptionPtr(exception_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { | ||||
|   static std::string prefix = c10::str("[Rank ", rank_, "] "); | ||||
|   return prefix; | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::WorkNCCL::setException( | ||||
|     std::exception_ptr exception_ptr) { | ||||
|   std::unique_lock<std::mutex> lock(mutex_); | ||||
| @ -527,9 +586,7 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( | ||||
|     return true; | ||||
|  | ||||
|   std::string exceptionMsg = c10::str( | ||||
|       "[Rank ", | ||||
|       rank_, | ||||
|       "] ", | ||||
|       logPrefix(), | ||||
|       "Watchdog caught collective operation timeout: ", | ||||
|       *this, | ||||
|       " ran for ", | ||||
| @ -550,13 +607,13 @@ void ProcessGroupNCCL::WorkNCCL::handleException( | ||||
|         "Some NCCL operations have failed or timed out. Due to the ", | ||||
|         "asynchronous nature of CUDA kernels, subsequent GPU operations ", | ||||
|         "might run on corrupted/incomplete data."); | ||||
|     LOG(ERROR) << exceptionMsg; | ||||
|     LOG(ERROR) << logPrefix() << exceptionMsg; | ||||
|     C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); | ||||
|  | ||||
|     if (SHOULD_TEAR_DOWN(errorHandling)) { | ||||
|       auto tearDownMsg = c10::str( | ||||
|           "To avoid data inconsistency, we are taking the entire process down."); | ||||
|       LOG(ERROR) << tearDownMsg; | ||||
|       LOG(ERROR) << logPrefix() << tearDownMsg; | ||||
|       std::rethrow_exception(exception_); | ||||
|     } | ||||
|   } | ||||
| @ -597,9 +654,8 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( | ||||
|       // can not run new events successfully. | ||||
|       if (timedOut) { | ||||
|         std::string exceptionMsg = c10::str( | ||||
|             "[Rank ", | ||||
|             rank_, | ||||
|             "] Work ", | ||||
|             logPrefix(), | ||||
|             "Work ", | ||||
|             (*this), | ||||
|             " timed out in blocking wait (TORCH_NCCL_BLOCKING_WAIT=1)."); | ||||
|         LOG(ERROR) << exceptionMsg; | ||||
| @ -713,6 +769,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|       ValueError, | ||||
|       at::cuda::getNumGPUs() != 0, | ||||
|       "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); | ||||
|   logPrefix_ = createLogPrefix(); | ||||
|   blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); | ||||
|   asyncErrorHandling_ = static_cast<ErrorHandlingMode>( | ||||
|       getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); | ||||
| @ -723,8 +780,18 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|   heartbeat_ = 1ULL; | ||||
|   monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true)); | ||||
|   heartbeatTimeoutInSec_ = | ||||
|       getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 2 /*2 Mins*/); | ||||
|       getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/); | ||||
|   waitTimeoutDumpInMilSec_ = | ||||
|       getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 2000); | ||||
|   coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); | ||||
|   ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0); | ||||
|   // store_ is usually wrapped with PrefixStore and the prefix differs | ||||
|   // across ProcessGroupNCCL (PG) instances. We need to get the | ||||
|   // underlying non-PrefixStore to share global information across | ||||
|   // different PGs. | ||||
|   PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get()); | ||||
|   globalStore_ = | ||||
|       prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; | ||||
| #ifdef ENABLE_NCCL_ERROR_CHECKING | ||||
|   enableTiming_.store( | ||||
|       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); | ||||
| @ -737,15 +804,15 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|           expandable_segments()) { | ||||
|     useTensorRegisterAllocatorHook_ = false; | ||||
|     LOG(INFO) | ||||
|         << "[Rank " << rank_ | ||||
|         << "] disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; | ||||
|         << logPrefix() | ||||
|         << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   if (blockingWait_) { | ||||
|     if (asyncErrorHandling_ != NoHandling || desyncDebug_) { | ||||
|       LOG(INFO) | ||||
|           << "[Rank " << rank_ << "] TORCH_NCCL_BLOCKING_WAIT and " | ||||
|           << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and " | ||||
|           << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" | ||||
|           << "should not both be enabled. " | ||||
|           << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; | ||||
| @ -755,8 +822,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|   } else { | ||||
|     if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { | ||||
|       LOG(INFO) | ||||
|           << "[Rank " << rank_ | ||||
|           << "] TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " | ||||
|           << logPrefix() | ||||
|           << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " | ||||
|           << "must both be enabled. " | ||||
|           << "Enabling TORCH_NCCL_ASYNC_ERROR_HANDLING."; | ||||
|       asyncErrorHandling_ = SkipCleanUp; | ||||
| @ -781,12 +848,13 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|   const std::string OFF = "OFF"; | ||||
|   std::string torch_distributed_debug = | ||||
|       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); | ||||
|   std::string nccl_debug = getCvarString({"NCCL_DEBUG"}, OFF.c_str()); | ||||
|   LOG(INFO) << "[Rank " << rank_ | ||||
|             << "] ProcessGroupNCCL initialization options: " | ||||
|             << "NCCL version: " << getNcclVersion() | ||||
|   LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " | ||||
|             << "NCCL version: " << getNcclVersion() << ", size: " << size | ||||
|             << ", global rank: " << globalRank() | ||||
|             << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ | ||||
|             << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeout_ | ||||
|             << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " | ||||
|             << waitTimeoutDumpInMilSec_ | ||||
|             << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ | ||||
|             << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() | ||||
|             << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ | ||||
| @ -804,7 +872,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|             << monitorThreadEnabled_.load() | ||||
|             << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ | ||||
|             << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ | ||||
|             << ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID(); | ||||
|             << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ | ||||
|             << ", ID=" << this->getID(); | ||||
|  | ||||
|   RECORD_PARAM_COMMS( | ||||
|       0, // seq | ||||
| @ -835,7 +904,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
| void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { | ||||
|   std::vector<at::Device> rankDevices = {device}; | ||||
|   const auto key = getKeyFromDevices(rankDevices); | ||||
|   LOG(INFO) << "Eagerly connecting nccl backend with device " << device; | ||||
|   LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " | ||||
|             << device; | ||||
|   getNCCLComm(key, rankDevices, OpType::ALLREDUCE); | ||||
| } | ||||
|  | ||||
| @ -846,8 +916,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { | ||||
| #ifdef NCCL_HAS_COMM_SPLIT | ||||
|   std::vector<at::Device> rankDevices = {device}; | ||||
|   const auto key = getKeyFromDevices(rankDevices); | ||||
|   LOG(INFO) << "Performing nocolor split on backend device " << device | ||||
|             << ", key " << key << ", i am " << this; | ||||
|   LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " | ||||
|             << device << ", key " << key << ", i am " << this; | ||||
|   auto comm = getNCCLComm(key, rankDevices, OpType::ALLREDUCE); | ||||
|   TORCH_CHECK_WITH( | ||||
|       DistBackendError, | ||||
| @ -897,8 +967,7 @@ void ProcessGroupNCCL::runHealthCheck() { | ||||
|   // We don't need to join the thread, just need to verify health check via the | ||||
|   // CV. Hence we detach the thread here. | ||||
|   t.detach(); // NOLINT | ||||
|   LOG(INFO) << "[Rank " << rank_ << "]" | ||||
|             << " will wait up to " << options_->timeout.count() | ||||
|   LOG(INFO) << logPrefix() << "will wait up to " << options_->timeout.count() | ||||
|             << " msec for NCCL health check to complete."; | ||||
|   std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex); | ||||
|   healthCheckData.healthCheckCv.wait_for( | ||||
| @ -1002,10 +1071,49 @@ std::future<bool> ProcessGroupNCCL::launchAsyncDebugDump() { | ||||
|   return resultFuture; | ||||
| } | ||||
|  | ||||
| void abortCommsFromMap( | ||||
| std::chrono::time_point<std::chrono::steady_clock> getWakeupTime( | ||||
|     int intervalInMilSec) { | ||||
|   return std::chrono::steady_clock::now() + | ||||
|       std::chrono::milliseconds(intervalInMilSec); | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::waitForDumpOrTimeout( | ||||
|     std::future<bool>& fut, | ||||
|     const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime, | ||||
|     size_t timeout_sec) { | ||||
|   TORCH_CHECK(fut.valid(), "Expected a valid future"); | ||||
|  | ||||
|   auto futStatus = fut.wait_for(std::chrono::seconds(timeout_sec)); | ||||
|   TORCH_CHECK( | ||||
|       futStatus != std::future_status::deferred, "Expected eager launch."); | ||||
|   if (futStatus == std::future_status::ready) { | ||||
|     // Calling .get() will re-raise any exception from the future, and we don't | ||||
|     // care about the retval | ||||
|     try { | ||||
|       fut.get(); | ||||
|       std::this_thread::sleep_until(wakeUpTime); | ||||
|     } catch (const std::exception& e) { | ||||
|       LOG(ERROR) << logPrefix() | ||||
|                  << "Caught exception during async debug dump: \"" << e.what() | ||||
|                  << "\"\n"; | ||||
|     } catch (...) { | ||||
|       LOG(ERROR) << logPrefix() | ||||
|                  << "Caught unknown exception during async debug dump."; | ||||
|     } | ||||
|   } else { | ||||
|     LOG(INFO) | ||||
|         << logPrefix() << "Debug dump timed out and is being abandoned." | ||||
|         << " This may be due to slow ADDR2LINE performance processing stacktraces." | ||||
|         << " Try TORCH_DISABLE_ADDR2LINE=1 and TORCH_NCCL_TRACE_CPP_STACK=0 to work around."; | ||||
|   } | ||||
|   // Ensure we sleep at least until wakeUpTime regardless of future execution | ||||
|   // time | ||||
|   std::this_thread::sleep_until(wakeUpTime); | ||||
| } | ||||
|  | ||||
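Note that `waitForDumpOrTimeout()` takes a wake-up time computed *before* the dump is launched, so the extra delay is an absolute deadline rather than being added on top of however long the dump itself takes, presumably to give peer ranks time to act on the timeout signal before this process aborts. A small self-contained sketch of that deadline pattern (the values and interpretation are assumptions):

```cpp
// Sketch of the absolute-deadline wait used by waitForDumpOrTimeout(): compute
// the wake-up time before launching the async work, wait on the future with a
// timeout, then hold until the deadline regardless of how the work finished.
#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  using Clock = std::chrono::steady_clock;
  const auto wakeUpTime = Clock::now() + std::chrono::milliseconds(2000);  // cf. waitTimeoutDumpInMilSec_

  // Stand-in for launchAsyncDebugDump(): work that may finish early or hang.
  std::future<bool> dump = std::async(std::launch::async, [] { return true; });

  if (dump.wait_for(std::chrono::seconds(30)) == std::future_status::ready) {
    std::cout << "dump finished: " << dump.get() << "\n";
  } else {
    std::cout << "dump timed out; abandoning it\n";
  }
  // Always hold until the deadline so other ranks that were signaled later
  // still get a chance to finish writing their own debug info.
  std::this_thread::sleep_until(wakeUpTime);
}
```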
| void ProcessGroupNCCL::abortCommsFromMap( | ||||
|     std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>& | ||||
|         ncclCommsMap, | ||||
|     const int rank, | ||||
|     c10::optional<std::string> abortReason) { | ||||
|   // The process may control multiple devices, loop through the communicators on | ||||
|   // each device | ||||
| @ -1014,6 +1122,8 @@ void abortCommsFromMap( | ||||
|     auto& ncclComms = it.second; | ||||
|  | ||||
|     for (const auto& ncclComm : ncclComms) { | ||||
|       LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " | ||||
|                 << ncclComm->ncclComm_ << " on CUDA device: " << devName; | ||||
|       ncclComm->ncclCommAbort(abortReason); | ||||
|     } | ||||
|     // Note that we don't remove the aborted communicators from the | ||||
| @ -1026,8 +1136,18 @@ void abortCommsFromMap( | ||||
|     // their responsibility to destroy the process group and recreate | ||||
|     // it to recover from errors. | ||||
|  | ||||
|     LOG(INFO) << "[Rank " << rank << "] Destroyed " << ncclComms.size() | ||||
|               << "communicators on CUDA device " << devName; | ||||
|     c10::StreamId streamId = -1; | ||||
|     if (ncclStreams_.find(devName) != ncclStreams_.end()) { | ||||
|       auto streams = ncclStreams_.at(devName); | ||||
|       if (streams.size() > 0) { | ||||
|         streamId = streams[0].id(); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " | ||||
|               << ncclComms.size() | ||||
|               << " communicators on CUDA device: " << devName | ||||
|               << " with stream: " << streamId; | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -1048,8 +1168,8 @@ void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) { | ||||
|   ncclCommDevIdxMapMutex.unlock(); | ||||
|  | ||||
|   std::lock_guard<std::mutex> lock(mutex_); | ||||
|   abortCommsFromMap(devNCCLCommMap_, rank_, abortReason); | ||||
|   abortCommsFromMap(inInitializationCommMap_, rank_, abortReason); | ||||
|   abortCommsFromMap(devNCCLCommMap_, abortReason); | ||||
|   abortCommsFromMap(inInitializationCommMap_, abortReason); | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::shutdown() { | ||||
| @ -1067,6 +1187,7 @@ void ProcessGroupNCCL::shutdown() { | ||||
| } | ||||
|  | ||||
| ProcessGroupNCCL::~ProcessGroupNCCL() { | ||||
|   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; | ||||
|   terminateProcessGroup_.store(true); | ||||
|   workMetaListCV_.notify_one(); | ||||
|  | ||||
| @ -1074,6 +1195,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { | ||||
|   if (ncclCommWatchdogThread_.joinable()) { | ||||
|     ncclCommWatchdogThread_.join(); | ||||
|   } | ||||
|   LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; | ||||
| #endif | ||||
|  | ||||
|   if (onCompletionHookThread_.joinable()) | ||||
| @ -1083,6 +1205,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { | ||||
|   // threads dying due to aborted communicator and raising a SIGABRT | ||||
|   std::string abortReason = c10::str("Process Group destroyed on rank ", rank_); | ||||
|   abort(abortReason); | ||||
|   LOG(INFO) << logPrefix() << "ProcessGroupNCCL abort finished."; | ||||
|  | ||||
|   // We need to wait for abort to finish before we can safely shut down | ||||
|   // heartbeat monitoring thread. | ||||
| @ -1095,33 +1218,20 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { | ||||
| #endif | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::registerDebugInfoWriter( | ||||
|     std::unique_ptr<DebugInfoWriter> writer) { | ||||
|   TORCH_CHECK_WITH( | ||||
|       DistBackendError, | ||||
|       debugInfoWriter_ == nullptr, | ||||
|       "ProcessGroupNCCL debugInfoWriter already registered"); | ||||
|   debugInfoWriter_ = std::move(writer); | ||||
| } | ||||
|  | ||||
| bool ProcessGroupNCCL::dumpDebuggingInfo() { | ||||
|   // Serialize all calls to this function to avoid corrupting data, but allow | ||||
|   // multiple calls in one runtime. User is responsible for preserving the | ||||
|   // output file from an earlier call before a later call overwrites it. | ||||
|   std::lock_guard<std::mutex> lock(writeDebugInfoMutex_); | ||||
|   LOG(ERROR) << "ProcessGroupNCCL preparing to dump debug info."; | ||||
|   static std::mutex writeDebugInfoMutex; | ||||
|   std::lock_guard<std::mutex> lock(writeDebugInfoMutex); | ||||
|   LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; | ||||
|   if (ncclTraceBufferSize_ > 0) { | ||||
|     // We dump nccl trace into local disk by default and users can register | ||||
|     // their customized writer by inheriting `DebugInfoWriter` via | ||||
|     // `registerDebugInfoWriter`. | ||||
|     auto ncclTrace = dump_nccl_trace(); | ||||
|     if (debugInfoWriter_ == nullptr) { | ||||
|       // Dump the trace blob into local disk as a fallback. | ||||
|       std::unique_ptr<DebugInfoWriter> debugInfoWriterPtr = | ||||
|           std::make_unique<DebugInfoWriter>(rank_); | ||||
|       registerDebugInfoWriter(std::move(debugInfoWriterPtr)); | ||||
|     } | ||||
|     debugInfoWriter_->write(ncclTrace); | ||||
|     DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); | ||||
|     writer.write(ncclTrace); | ||||
|     return true; | ||||
|   } | ||||
|   return false; | ||||
| @ -1130,48 +1240,133 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { | ||||
| void ProcessGroupNCCL::terminateProcess(std::string errMsg) { | ||||
|   // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` | ||||
|   // to terminate the program execution. | ||||
|   LOG(FATAL) << errMsg; | ||||
|   LOG(FATAL) << logPrefix() << errMsg; | ||||
| } | ||||
|  | ||||
| int computeDeltaMS( | ||||
|     std::chrono::time_point<std::chrono::steady_clock> start, | ||||
|     std::chrono::time_point<std::chrono::steady_clock> end) { | ||||
|   return std::chrono::duration_cast<std::chrono::milliseconds>(end - start) | ||||
|       .count(); | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::heartbeatMonitor() { | ||||
|   uint64_t heartBeatCounter = 0ULL; | ||||
|   std::string errorMsg; | ||||
|   std::string exitMsg; | ||||
|   bool checkTimeoutSignal = (dumpOnTimeout_ && uid_ == 0); | ||||
|   int monitorPollInterval = checkTimeoutSignal ? coordCheckIntervalMilSec_ | ||||
|                                                : heartbeatTimeoutInSec_ * 1000; | ||||
|   auto lastTimePollStore = std::chrono::steady_clock::now(); | ||||
|   auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); | ||||
|   std::future<bool> asyncDebugDump; | ||||
|   while (true) { | ||||
|     // This won't have any lock since this lock is only used here. | ||||
|     // Please be aware that mutex `monitorMutex_` should not be used | ||||
|     // somewhere else to avoid the deadlock. | ||||
|     std::unique_lock<std::mutex> lock(monitorMutex_); | ||||
|     if (monitorWakeUpCV_.wait_for( | ||||
|             lock, std::chrono::seconds(heartbeatTimeoutInSec_), [&] { | ||||
|             lock, std::chrono::milliseconds(monitorPollInterval), [&] { | ||||
|               return terminateHeartbeatMonitorThread_.load(); | ||||
|             })) { | ||||
|       // For the normal complete or user interception, monitorWakeUpCV_ | ||||
|       // will get notified, we early return and exit heartbeatMonitor. | ||||
|       return; | ||||
|     } | ||||
|     auto currentTime = std::chrono::steady_clock::now(); | ||||
|  | ||||
|     // Check the heart beat of watchdog thread. | ||||
|     auto heartbeat = heartbeat_; | ||||
|     if (heartbeat != heartBeatCounter) { | ||||
|       heartBeatCounter = heartbeat; | ||||
|     } else { | ||||
|       // No heartbeat increase detected and timeout. | ||||
|       break; | ||||
|     // We put extra functionality in the thread for the default PG (i.e., uid_=0) | ||||
|     // because the signal is the same across different PGs. We only need to run | ||||
|     // this once per process to avoid duplicating the work in many separate | ||||
|     // threads. For example, we check a global flag on the TCPStore periodically | ||||
|     // to see if any PG on any rank observed a timeout and signaled peers to | ||||
|     // dump debugging info, and we avoid hammering the TCPStore from all PGs on | ||||
|     // the same rank. | ||||
|     if (checkTimeoutSignal) { | ||||
|       // We poll the store to see if some ranks have flagged a timeout when | ||||
|       // we haven't polled for `coordCheckIntervalMilSec_` milliseconds and no | ||||
|       // work has been added or removed for `kWatchdogThreadSleepMillis` milliseconds. | ||||
|       if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= | ||||
|               kWatchdogThreadSleepMillis && | ||||
|           computeDeltaMS(lastTimePollStore, currentTime) >= | ||||
|               coordCheckIntervalMilSec_) { | ||||
|         lastTimePollStore = currentTime; | ||||
|         if (globalStore_->check({std::string(TIMEOUT_DUMP)})) { | ||||
|           errorMsg = c10::str( | ||||
|               logPrefix(), | ||||
|               "Received a global timeout from another rank and will ", | ||||
|               "start to dump the debug info."); | ||||
|           exitMsg = c10::str( | ||||
|               "ProcessGroupNCCL's watchdog detected a collective timeout and notified current rank. ", | ||||
|               "This is most likely caused by incorrect usages of collectives, e.g., wrong ", | ||||
|               "sizes used across ranks, the order of collectives is not same for all ranks ", | ||||
|               "or the scheduled collective, for some reason, didn't run. Additionally, ", | ||||
|               "this can be caused by GIL deadlock or other reasons such as network errors or ", | ||||
|               "bugs in the communications library (e.g. NCCL), etc. We tried our best to ", | ||||
|               "dump the debug info into the storage to help you debug the issue."); | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >= | ||||
|         heartbeatTimeoutInSec_ * 1000) { | ||||
|       // Check the heart beat of watchdog thread. | ||||
|       lastTimeHeartBeatCheck = currentTime; | ||||
|       auto heartbeat = heartbeat_.load(); | ||||
|       if (heartbeat != heartBeatCounter) { | ||||
|         heartBeatCounter = heartbeat; | ||||
|       } else { | ||||
|         // No heartbeat increase detected and timeout. | ||||
|         errorMsg = c10::str( | ||||
|             logPrefix(), | ||||
|             "Heartbeat monitor timed out! Process will be terminated after dumping debug info.", | ||||
|             " workMetaList_.size()=", | ||||
|             workMetaList_.size()); | ||||
|         exitMsg = c10::str( | ||||
|             "ProcessGroupNCCL's watchdog got stuck for ", | ||||
|             heartbeatTimeoutInSec_, | ||||
|             " seconds without making progress in monitoring enqueued collectives. ", | ||||
|             "This typically indicates a NCCL/CUDA API hang blocking the watchdog, ", | ||||
|             "and could be triggered by another thread holding the GIL inside a ", | ||||
|             "CUDA api, or other deadlock-prone behaviors.", | ||||
|             "If you suspect the watchdog is not actually stuck and a longer timeout would help, ", | ||||
|             "you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value " | ||||
|             "or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0)." | ||||
|             "If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout " | ||||
|             "or false positive abort; otherwise, please attempt to debug the hang."); | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   LOG(ERROR) << errorMsg; | ||||
|  | ||||
|   auto& cpp_dumper = get_cpp_trace_dumper(); | ||||
|   if (cpp_dumper.has_value()) { | ||||
|     LOG(INFO) << "Dumping c++ stacktraces: " << cpp_dumper.value()(); | ||||
|   } | ||||
|  | ||||
|   auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_); | ||||
|   // Store debug info to storage if no other thread does it. (By default to | ||||
|   // local disk) | ||||
|   std::future<bool> asyncDebugDump = launchAsyncDebugDump(); | ||||
|   asyncDebugDump = launchAsyncDebugDump(); | ||||
|  | ||||
|   // Create a error message reported from MonitorThread, so | ||||
|   // we throw exception and make the whole process to be killed. | ||||
|   const auto exitMsg = c10::str( | ||||
|       "[Rank ", | ||||
|       rank_, | ||||
|       "] NCCL monitor thread timeout. Basically, this could ", | ||||
|       "be due to CUDA or NCCL calls being unexpectedly blocking, ", | ||||
|       "especially when your program enters a deadlock state in watchdog " | ||||
|       "or destructors. If you see this error, please file a bug to PyTorch."); | ||||
|   if (get_gil_checker() != nullptr) { | ||||
|     auto fut = launchAsyncGilCheck(); | ||||
|     auto kGilCheckTimeout = std::chrono::milliseconds(300); | ||||
|     auto futStatus = fut.wait_for(kGilCheckTimeout); | ||||
|     if (futStatus != std::future_status::ready) { | ||||
|       TORCH_CHECK( | ||||
|           futStatus != std::future_status::deferred, | ||||
|           "Expected the future to have been launched eagerly."); | ||||
|       LOG(ERROR) | ||||
|           << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; | ||||
|     } | ||||
|     LOG(INFO) << "Could acquire GIL on exit"; | ||||
|   } else { | ||||
|     LOG(INFO) | ||||
|         << "GIL checker was not registered, perhaps this is a no-python build?"; | ||||
|   } | ||||
|  | ||||
|   // There are two possible cases for the watchdog thread exit: | ||||
|   // Case one: desync report runs quickly, and it follows the step: | ||||
| @ -1197,46 +1392,40 @@ void ProcessGroupNCCL::heartbeatMonitor() { | ||||
|   // We already log completion inside the thread, so it may not be necessary to | ||||
|   // check the return value here.  We mainly use a future so we can exit early | ||||
|   // if done. | ||||
|   asyncDebugDump.wait_for(std::chrono::seconds(heartbeatTimeoutInSec_)); | ||||
|   waitForDumpOrTimeout(asyncDebugDump, wakeUpTime); | ||||
|  | ||||
|   if (!terminateHeartbeatMonitorThread_.load()) { | ||||
|     const auto logMsg = c10::str( | ||||
|         "[Rank ", | ||||
|         rank_, | ||||
|         "] monitoring thread detects no heartbeat and will finally kill the process!", | ||||
|         " terminateProcessGroup_", | ||||
|         terminateProcessGroup_, | ||||
|         " collectiveDebugInfoMode_", | ||||
|         collectiveDebugInfoMode_); | ||||
|     LOG(ERROR) << logMsg; | ||||
|     terminateProcess(exitMsg); | ||||
|     // Create an error message reported from MonitorThread, so | ||||
|     // we throw an exception and kill the whole process. | ||||
|     // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki | ||||
|     // url here. | ||||
|     const auto finalExitMsg = c10::str(logPrefix(), exitMsg); | ||||
|     terminateProcess(finalExitMsg); | ||||
|   } | ||||
| } | ||||
|  | ||||
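The coordination implemented in `heartbeatMonitor()` is just a key on the shared store: the rank whose watchdog first detects a timeout `set()`s the key (see `watchdogHandler()` further down), and every default-PG monitor thread `check()`s it at `coordCheckIntervalMilSec_` intervals. A minimal sketch against the `c10d::Store` calls used in this file; the key name and include path are illustrative, not the real `TIMEOUT_DUMP` constant:

```cpp
// Sketch of the store-based timeout signal, using only the Store::set/check
// calls that appear in this file. The key name and include path are assumed.
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <string>
#include <vector>

constexpr const char* kTimeoutDumpKey = "timeout_dump";  // stand-in for TIMEOUT_DUMP

// Called by the rank whose watchdog first observes a collective timeout.
void signalTimeoutDump(c10d::Store& globalStore) {
  std::vector<uint8_t> dummy(1);
  globalStore.set(kTimeoutDumpKey, dummy);  // the value is irrelevant; presence is the signal
}

// Polled periodically by the default PG's heartbeat monitor on every rank.
bool timeoutDumpRequested(c10d::Store& globalStore) {
  return globalStore.check({std::string(kTimeoutDumpKey)});
}
```

Using the non-prefixed global store (resolved via `getUnderlyingNonPrefixStore()`) is what makes the key visible across all process groups rather than being scoped to one PG's prefix.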
| void ProcessGroupNCCL::ncclCommWatchdog() { | ||||
|   try { | ||||
|     VLOG(2) << "[Rank " << rank_ << "] NCCL watchdog thread started!"; | ||||
|     VLOG(2) << logPrefix() << "NCCL watchdog thread started!"; | ||||
|     if (monitorThreadEnabled_.load()) { | ||||
|       ncclHeartbeatMonitorThread_ = | ||||
|           std::thread(&ProcessGroupNCCL::heartbeatMonitor, this); | ||||
|     } | ||||
|     watchdogHandler(); | ||||
|     VLOG(2) << "[Rank " << rank_ | ||||
|             << "] NCCL watchdog thread terminated normally"; | ||||
|     VLOG(2) << logPrefix() << "NCCL watchdog thread terminated normally"; | ||||
|   } catch (std::exception& e) { | ||||
|     if (std::string(e.what()).find("driver shutting down") != | ||||
|         std::string::npos) { | ||||
|       LOG(INFO) | ||||
|           << "[Rank " << rank_ | ||||
|           << "] main process destroyed cuda before watchdog loop exited, terminating watchdog." | ||||
|           << logPrefix() | ||||
|           << "main process destroyed cuda before watchdog loop exited, terminating watchdog." | ||||
|           << " (Watchdog caught exception: " << e.what(); | ||||
|  | ||||
|     } else { | ||||
|       // Append error message reported from watchdogHandler | ||||
|       const auto exitMsg = c10::str( | ||||
|           "[Rank ", | ||||
|           rank_, | ||||
|           "] NCCL watchdog thread terminated with exception: ", | ||||
|           logPrefix(), | ||||
|           "NCCL watchdog thread terminated with exception: ", | ||||
|           e.what()); | ||||
|       LOG(ERROR) << exitMsg; | ||||
|       // TODO(whc) clean up the rethrow - why is it stored in a class var and | ||||
| @ -1247,9 +1436,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { | ||||
|     } | ||||
|   } catch (...) { | ||||
|     const auto exitMsg = c10::str( | ||||
|         "[Rank ", | ||||
|         rank_, | ||||
|         "] NCCL watchdog thread terminated with exception: unknown"); | ||||
|         logPrefix(), "NCCL watchdog thread terminated with exception: unknown"); | ||||
|     LOG(ERROR) << exitMsg; | ||||
|     watchDogException_ = | ||||
|         std::make_exception_ptr(C10_BUILD_ERROR(DistBackendError, exitMsg)); | ||||
| @ -1288,12 +1475,13 @@ std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() { | ||||
|  | ||||
| #if defined(__linux__) | ||||
| struct DumpPipe { | ||||
|   DumpPipe(bool enabled, int rank) { | ||||
|     if (!enabled) { | ||||
|   DumpPipe(int rank) { | ||||
|     std::string fileStem = | ||||
|         getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, ""); | ||||
|     if (fileStem.empty() || | ||||
|         getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { | ||||
|       return; | ||||
|     } | ||||
|     std::string fileStem = getCvarString( | ||||
|         {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|     TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty"); | ||||
|     std::string filename = c10::str(fileStem, rank, ".pipe"); | ||||
|     TORCH_CHECK( | ||||
| @ -1305,6 +1493,8 @@ struct DumpPipe { | ||||
|         "Error creating named pipe ", | ||||
|         filename); | ||||
|     fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); | ||||
|     LOG(INFO) << "Pipe file " << filename | ||||
|               << " has been opened, write to it to trigger NCCL Debug Dump."; | ||||
|     TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename); | ||||
|   } | ||||
|   bool shouldDump() { | ||||
| @ -1329,22 +1519,40 @@ struct DumpPipe { | ||||
| }; | ||||
| #else | ||||
| struct DumpPipe { | ||||
|   DumpPipe(bool enabled, int rank) {} | ||||
|   DumpPipe(int rank) {} | ||||
|   bool shouldDump() { | ||||
|     return false; | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
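Since the watchdog keeps the FIFO open for non-blocking reads, an external process can request an on-demand flight-recorder dump simply by writing a byte to it. A small sketch of such a trigger, assuming `TORCH_NCCL_DEBUG_INFO_PIPE_FILE=/tmp/nccl_trace_rank_` and global rank 0 (both are illustrative choices, not defaults):

```cpp
// Sketch: trigger an on-demand NCCL debug dump by writing to the DumpPipe FIFO.
// The path assumes TORCH_NCCL_DEBUG_INFO_PIPE_FILE=/tmp/nccl_trace_rank_ and
// global rank 0; adjust to your own setting.
#include <fcntl.h>
#include <unistd.h>

#include <cstdio>

int main() {
  const char* pipePath = "/tmp/nccl_trace_rank_0.pipe";
  // O_NONBLOCK: fail immediately if no watchdog currently has the FIFO open
  // for reading, instead of blocking this tool.
  int fd = open(pipePath, O_WRONLY | O_NONBLOCK);
  if (fd == -1) {
    std::perror("open dump pipe");
    return 1;
  }
  char byte = 1;  // any single byte is enough to signal the reader
  ssize_t written = write(fd, &byte, 1);
  close(fd);
  return written == 1 ? 0 : 1;
}
```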
| std::string ProcessGroupNCCL::createLogPrefix() const { | ||||
|   return c10::str("[PG ", uid_, " Rank ", rank_, "] "); | ||||
| } | ||||
|  | ||||
| const std::string& ProcessGroupNCCL::logPrefix() const { | ||||
|   return logPrefix_; | ||||
| } | ||||
|  | ||||
| const int& ProcessGroupNCCL::globalRank() const { | ||||
|   static int globalRank = rank_; | ||||
|   return globalRank; | ||||
| } | ||||
|  | ||||
| void ProcessGroupNCCL::watchdogHandler() { | ||||
|   bool done = false; | ||||
|   lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); | ||||
|   auto lastTimePollStore = std::chrono::steady_clock::now(); | ||||
|   c10::optional<std::future<bool>> optAsyncDebugDump; | ||||
|  | ||||
|   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList; | ||||
|  | ||||
|   DumpPipe dumpPipe(dumpOnTimeout_, rank_); | ||||
|   c10::optional<DumpPipe> dumpPipe = c10::nullopt; | ||||
|   if (uid_ == 0) { | ||||
|     // DumpPipe is one per trainer process, and it's convenient to name them | ||||
|     // after 'global' ranks in the system, so we assume process group (uid)==0 is | ||||
|     // the global PG and has globally unique rank ids across trainers. | ||||
|     dumpPipe.emplace(rank_); | ||||
|   } | ||||
|   while (!done || !terminateProcessGroup_.load()) { | ||||
|     std::unique_lock<std::mutex> lock(workMetaListMutex_); | ||||
|     // We busy-poll the work vector every kWatchdogThreadSleepMillis | ||||
| @ -1356,37 +1564,6 @@ void ProcessGroupNCCL::watchdogHandler() { | ||||
|     // Bump up heart beat by one. | ||||
|     heartbeat_++; | ||||
|  | ||||
|     // poll store to see if some ranks have flagged a timeout when | ||||
|     // we haven't polled for `heartbeat_timeout` seconds and there haven't | ||||
|     // any work added or removed for `watchdog_timeout` seconds. | ||||
|     if (dumpOnTimeout_) { | ||||
|       auto currentTime = std::chrono::steady_clock::now(); | ||||
|       auto timeSinceLastWorkListUpdate = | ||||
|           std::chrono::duration_cast<std::chrono::milliseconds>( | ||||
|               (currentTime - lastWorkListUpdateTime_)) | ||||
|               .count(); | ||||
|       auto timeSinceLastPollStore = | ||||
|           std::chrono::duration_cast<std::chrono::milliseconds>( | ||||
|               (currentTime - lastTimePollStore)) | ||||
|               .count(); | ||||
|       if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis && | ||||
|           timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) { | ||||
|         lastTimePollStore = currentTime; | ||||
|         if (store_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) { | ||||
|           optAsyncDebugDump = launchAsyncDebugDump(); | ||||
|           optAsyncDebugDump->wait_for( | ||||
|               std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30)); | ||||
|           const auto exitMsg = c10::str( | ||||
|               "Some other rank's watchdog thread detected a timeout and notified ", | ||||
|               "all other ranks, so we're dumping debug info and aborting [Rank ", | ||||
|               rank_, | ||||
|               "] as well."); | ||||
|           LOG(ERROR) << exitMsg; | ||||
|           C10_THROW_ERROR(DistBackendError, exitMsg); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     for (auto it = workMetaList_.begin(); it != workMetaList_.end(); | ||||
|          /* no increment */) { | ||||
|       auto& work = *it; | ||||
| @ -1411,9 +1588,10 @@ void ProcessGroupNCCL::watchdogHandler() { | ||||
|               // abort process immediately. | ||||
|               collectiveDebugInfoMode_.store(true); | ||||
|               std::vector<uint8_t> vec(1); | ||||
|               store_->set(std::string(TIMEOUT_DUMP), vec); | ||||
|               globalStore_->set(std::string(TIMEOUT_DUMP), vec); | ||||
|             } | ||||
|  | ||||
|             auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_); | ||||
|             if (dumpOnTimeout_ && !optAsyncDebugDump) { | ||||
|               // Store debug info to storage. (By default to local disk) | ||||
|               optAsyncDebugDump = launchAsyncDebugDump(); | ||||
| @ -1421,20 +1599,21 @@ void ProcessGroupNCCL::watchdogHandler() { | ||||
|  | ||||
|             if (desyncDebug_) { | ||||
|               auto desyncMsg = getNCCLWatchdogDebugInfo(); | ||||
|               LOG(ERROR) << desyncMsg; | ||||
|               LOG(ERROR) << logPrefix() << desyncMsg; | ||||
|             } | ||||
|  | ||||
|             if (dumpOnTimeout_) { | ||||
|               // Store debug info to storage. (By default to local disk) | ||||
|               optAsyncDebugDump->wait_for( | ||||
|                   std::chrono::milliseconds(kWatchdogThreadSleepMillis * 30)); | ||||
|               waitForDumpOrTimeout(*optAsyncDebugDump, wakeUpTime); | ||||
|             } | ||||
|  | ||||
|           } catch (const std::exception& e) { | ||||
|             LOG(ERROR) << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " | ||||
|             LOG(ERROR) << logPrefix() | ||||
|                        << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " | ||||
|                        << " Please file an issue. Error: " << e.what(); | ||||
|           } catch (...) { | ||||
|             LOG(ERROR) | ||||
|                 << logPrefix() | ||||
|                 << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." | ||||
|                 << " Please file an issue."; | ||||
|           } | ||||
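
The store-based coordination that replaces the removed inline polling has two halves: the rank whose watchdog times out writes the TIMEOUT_DUMP key into the global (non-prefixed) store, and every other rank periodically checks for that key (presumably on the TORCH_NCCL_COORD_CHECK_MILSEC interval introduced in this change) and dumps when it appears. The sketch below is a simplified, free-standing illustration of that protocol against the abstract Store API; signalTimeout and pollAndMaybeDump are not functions from this patch, and the dump callback merely stands in for launchAsyncDebugDump().

#include <torch/csrc/distributed/c10d/Store.hpp>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Signal side: the rank that detected a watchdog timeout flags all other ranks.
void signalTimeout(const c10::intrusive_ptr<c10d::Store>& globalStore) {
  std::vector<uint8_t> vec(1);
  globalStore->set("timeout_dump", vec); // matches the TIMEOUT_DUMP key
}

// Observer side: called periodically by every rank's watchdog.
bool pollAndMaybeDump(
    const c10::intrusive_ptr<c10d::Store>& globalStore,
    const std::function<void()>& dumpFn) {
  if (globalStore->check({"timeout_dump"})) {
    dumpFn(); // e.g. dump the flight recorder via the registered DebugInfoWriter
    return true;
  }
  return false;
}
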
| @ -1455,7 +1634,7 @@ void ProcessGroupNCCL::watchdogHandler() { | ||||
|  | ||||
|       // Clean up completed work | ||||
|       if (work.isCompleted()) { | ||||
|         NCCLTraceBuffer::get()->retire_id(work.trace_id_); | ||||
|         NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); | ||||
|         if (onCompletionHook_) { | ||||
|           // Move Work object to completedWorkList_ to be consumed by the hook | ||||
|           // thread | ||||
| @ -1475,9 +1654,14 @@ void ProcessGroupNCCL::watchdogHandler() { | ||||
|         // completed. | ||||
|         ++it; | ||||
|       } | ||||
|       // Increment heartbeat after each work processed, | ||||
|       // in case processing is slowed down (but not hung) by CUDA API contention. | ||||
|       heartbeat_++; | ||||
|     } | ||||
|     // process a request to dump the trace | ||||
|     if (dumpPipe.shouldDump()) { | ||||
|     // process a request to dump the trace. Only PG uid 0 will respond to dump | ||||
|     // requests, but this is fine since all PGs feed into the same flight | ||||
|     // recorder and dump. | ||||
|     if (dumpPipe.has_value() && dumpPipe->shouldDump()) { | ||||
|       launchAsyncDebugDump(); | ||||
|     } | ||||
|     done = workMetaList_.empty(); | ||||
| @ -1523,8 +1707,8 @@ void ProcessGroupNCCL::runHookLoop() { | ||||
|       if (std::string(e.what()).find("driver shutting down") != | ||||
|           std::string::npos) { | ||||
|         LOG(INFO) | ||||
|             << "[Rank " << rank_ | ||||
|             << "] main process destroyed cuda before runHookLoop exited, terminating runHookLoop." | ||||
|             << logPrefix() | ||||
|             << "main process destroyed cuda before runHookLoop exited, terminating runHookLoop." | ||||
|             << " (runHookLoop caught exception: " << e.what(); | ||||
|  | ||||
|       } else { | ||||
| @ -1698,7 +1882,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm( | ||||
|   if (bound_device_id_) { | ||||
|     for (const auto& device : devices) { | ||||
|       if (*bound_device_id_ != device) { | ||||
|         LOG(ERROR) << "Tensor found on device " << device | ||||
|         LOG(ERROR) << logPrefix() << "Tensor found on device " << device | ||||
|                    << " but backend constrained to " << *bound_device_id_; | ||||
|         C10_THROW_ERROR( | ||||
|             DistBackendError, | ||||
| @ -1848,11 +2032,16 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm( | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   for (const auto i : c10::irange(devices.size())) { | ||||
|     int deviceIndex = devices[i].index(); | ||||
|     LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " | ||||
|               << ncclComms[i]->ncclComm_ << " on CUDA device: " << deviceIndex; | ||||
|   } | ||||
|  | ||||
|   // At this point NCCL should have been initialized, hence we can accurately | ||||
|   // get the env value even if NCCL sets it by reading from nccl.conf file | ||||
|   if (getRank() == 0) { | ||||
|     LOG(INFO) << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); | ||||
|   } | ||||
|   LOG(INFO) << logPrefix() | ||||
|             << "NCCL_DEBUG: " << getCvarString({"NCCL_DEBUG"}, "N/A"); | ||||
|  | ||||
|   // See [Group Start/End Note] | ||||
|   for (const auto i : c10::irange(ncclActiveGroupCounter_)) { | ||||
| @ -1897,18 +2086,17 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm( | ||||
|               segmentInfo.total_size); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Record the mapping between ncclComm and device index so that later | ||||
|       // register hook can register a newly allocated segment to communicators | ||||
|       // on the same device. | ||||
|       // NOTE: we need remove the communicator from this map when it is | ||||
|       // destroyed, otherwise may register onto an invalid communicator. | ||||
|       ncclCommDevIdxMapMutex.lock(); | ||||
|       for (const auto i : c10::irange(devices.size())) { | ||||
|         ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index()); | ||||
|       } | ||||
|       ncclCommDevIdxMapMutex.unlock(); | ||||
|     } | ||||
|     // Record the mapping between ncclComm and device index so that later | ||||
|     // register hook can register a newly allocated segment to communicators | ||||
|     // on the same device. | ||||
|     // NOTE: we need remove the communicator from this map when it is | ||||
|     // destroyed, otherwise may register onto an invalid communicator. | ||||
|     ncclCommDevIdxMapMutex.lock(); | ||||
|     for (const auto i : c10::irange(devices.size())) { | ||||
|       ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index()); | ||||
|     } | ||||
|     ncclCommDevIdxMapMutex.unlock(); | ||||
|   } | ||||
|  | ||||
|   it = devNCCLCommMap_.find(devicesKey); | ||||
| @ -2130,7 +2318,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork( | ||||
|   r->trace_id_ = NCCLTraceBuffer::get()->record( | ||||
|       uid_, | ||||
|       seq_, | ||||
|       profilingTitle, | ||||
|       // create a string copy of profilingTitle | ||||
|       profilingTitle ? profilingTitle : "", | ||||
|       inputs, | ||||
|       outputs, | ||||
|       r->ncclStartEvents_.get(), | ||||
| @ -2148,18 +2337,16 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL:: | ||||
| } | ||||
|  | ||||
| float ProcessGroupNCCL::WorkNCCL::getDuration() const { | ||||
|   TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled") | ||||
|   TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); | ||||
|   TORCH_CHECK( | ||||
|       ncclStartEvents_->size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup."); | ||||
|       ncclStartEvents_, | ||||
|       "getDuration only works if ncclStartEvents_ is populated, true if timing enabled"); | ||||
|   TORCH_CHECK( | ||||
|       ncclEndEvents_->size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup."); | ||||
|   TORCH_CHECK( | ||||
|       (*ncclEndEvents_)[0].query(), | ||||
|       "getDuration can only be called after work is succeeded.") | ||||
|   return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]); | ||||
|       ncclEndEvents_, | ||||
|       "getDuration only works if ncclEndEvents_ is populated, which should always be true"); | ||||
|   return getDurationFromFirstEvent(*ncclStartEvents_, *ncclEndEvents_); | ||||
| } | ||||
|  | ||||
| uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const { | ||||
|   return seq_; | ||||
| } | ||||
| @ -3420,14 +3607,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) { | ||||
|     // ensure that each process is on a different GPU | ||||
|     auto numGPUs = at::cuda::getNumGPUs(); | ||||
|     int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs); | ||||
|     LOG(INFO) << c10::str( | ||||
|         "Rank ", | ||||
|         this->getRank(), | ||||
|         " using GPU ", | ||||
|         deviceIdx, | ||||
|         " to perform barrier as devices used by this process are currently unknown. ", | ||||
|         "This can potentially cause a hang if this rank to GPU mapping is incorrect.", | ||||
|         "Specify device_ids in barrier() to force use of a particular device."); | ||||
|     LOG(INFO) | ||||
|         << logPrefix() | ||||
|         << c10::str( | ||||
|                " using GPU ", | ||||
|                deviceIdx, | ||||
|                " to perform barrier as devices used by this process are currently unknown. ", | ||||
|                "This can potentially cause a hang if this rank to GPU mapping is incorrect.", | ||||
|                "Specify device_ids in barrier() to force use of a particular device."); | ||||
|     devices.emplace_back(guessDeviceForRank()); | ||||
|   } else { | ||||
|     for (auto usedDeviceIdx : usedDeviceIdxs_) { | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| #include <atomic> | ||||
| #include <chrono> | ||||
| #include <future> | ||||
| #include <iostream> | ||||
| @ -12,6 +13,7 @@ | ||||
|  | ||||
| #include <torch/csrc/distributed/c10d/Backend.hpp> | ||||
| #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> | ||||
| #include <torch/csrc/distributed/c10d/PrefixStore.hpp> | ||||
| #include <torch/csrc/distributed/c10d/Store.hpp> | ||||
|  | ||||
| #include <ATen/DynamicLibrary.h> | ||||
| @ -26,49 +28,71 @@ | ||||
| #include <torch/custom_class.h> | ||||
|  | ||||
| namespace c10d { | ||||
| // Environment variable which controls whether we perform a NCCL healt check | ||||
| // Control whether we perform a NCCL health check or not | ||||
| // which ensures communicators are healthy at the beginning of init. | ||||
| static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = { | ||||
|     "TORCH_ENABLE_NCCL_HEALTH_CHECK", | ||||
|     "ENABLE_NCCL_HEALTH_CHECK"}; | ||||
|  | ||||
| // Environment variable which controls whether or not wait() is blocking or | ||||
| // non-blocking. | ||||
| // Control whether wait() is blocking or non-blocking. | ||||
| static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = { | ||||
|     "TORCH_NCCL_BLOCKING_WAIT", | ||||
|     "NCCL_BLOCKING_WAIT"}; | ||||
|  | ||||
| // Environment variable which controls whether or not we perform Async Error | ||||
| // Handling with NCCL. | ||||
| // Control whether or not we perform Async Error Handling with NCCL. | ||||
| static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = { | ||||
|     "TORCH_NCCL_ASYNC_ERROR_HANDLING", | ||||
|     "NCCL_ASYNC_ERROR_HANDLING"}; | ||||
|  | ||||
| // Environment Variable to control whether dumping debug info on watchdog | ||||
| // Control whether dumping debug info on watchdog | ||||
| // timeout is enabled. This variable must be set together with | ||||
| // TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0. | ||||
| static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = { | ||||
|     "TORCH_NCCL_DUMP_ON_TIMEOUT"}; | ||||
|  | ||||
| // Environment Variable to control whether Desync Debug is enabled. | ||||
| // This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING. | ||||
| // Control whether Desync Debug is enabled. This variable must be set | ||||
| // together with TORCH_NCCL_ASYNC_ERROR_HANDLING. | ||||
| static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = { | ||||
|     "TORCH_NCCL_DESYNC_DEBUG", | ||||
|     "NCCL_DESYNC_DEBUG"}; | ||||
|  | ||||
| // Enable recording start-events for all ProcessGroupNCCL collectives, and | ||||
| // compute accurate collective timing per-collective. (Note: end-events are | ||||
| // recorded by default. Turning on this flag can increase the chance of a | ||||
| // watchdog hang due to performing a CUDA event query, which eventually calls | ||||
| // the cudaEventElapsedTime() API.) | ||||
| static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = { | ||||
|     "TORCH_NCCL_ENABLE_TIMING", | ||||
|     "NCCL_ENABLE_TIMING"}; | ||||
|  | ||||
| // Enable monitoring thread which aborts the process when the ProcessGroupNCCL | ||||
| // Watchdog thread gets stuck and no heartbeat is detected after | ||||
| // TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL | ||||
| // APIs that may hang. It is useful to prevent jobs from being stuck for longer | ||||
| // than necessary, tying up cluster resources. | ||||
| static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = { | ||||
|     "TORCH_NCCL_ENABLE_MONITORING"}; | ||||
|  | ||||
| // Control the watchdog heartbeat timeout period after which the monitoring | ||||
| // thread will abort the process. | ||||
| static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = { | ||||
|     "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"}; | ||||
|  | ||||
| // The maximum number of events we store in the flight recorder's ring buffer. | ||||
| // (One event could be the start or end of a collective, for example). | ||||
| static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = { | ||||
|     "TORCH_NCCL_TRACE_BUFFER_SIZE"}; | ||||
|  | ||||
| // Control how much extra time we will wait for dumping the debugging info | ||||
| // before we exit and throw a timeout exception. | ||||
| static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = { | ||||
|     "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"}; | ||||
|  | ||||
| // Control the interval inside the watchdog thread to check the coordinated | ||||
| // signal from other ranks, e.g. to dump the debugging information. | ||||
| static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = { | ||||
|     "TORCH_NCCL_COORD_CHECK_MILSEC"}; | ||||
|  | ||||
| constexpr const char* NCCL_BACKEND_NAME = "nccl"; | ||||
|  | ||||
| constexpr const char* TIMEOUT_DUMP = "timeout_dump"; | ||||
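
Taken together, these variables gate the flight recorder, the timeout dump, and the heartbeat monitor. A hedged configuration sketch follows; the values are examples only, not defaults recommended by this change, and they must be set before the ProcessGroupNCCL constructor reads them via getCvar*.

#include <cstdlib>

// Example wiring only; every value here is an illustrative assumption.
void configureNcclDebugFeatures() {
  setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "2000", /*overwrite=*/1);  // enable the flight recorder ring buffer
  setenv("TORCH_NCCL_DUMP_ON_TIMEOUT", "1", 1);                     // dump debug info on watchdog timeout
  setenv("TORCH_NCCL_ENABLE_MONITORING", "1", 1);                   // heartbeat monitor thread
  setenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "600", 1);             // abort if the watchdog goes silent this long
  setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", "/tmp/nccl_trace_rank_", 1);  // where dumps are written
  setenv("TORCH_NCCL_DEBUG_INFO_PIPE_FILE", "/tmp/nccl_trace_rank_", 1);  // enables the on-demand DumpPipe trigger
}
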
| @ -205,6 +229,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|     uint64_t getSequencenumber() const override; | ||||
|  | ||||
|     const std::string& logPrefix() const; | ||||
|  | ||||
|     // Helper function that sets an exception_ptr on the WorkNCCL object. | ||||
|     void setException(std::exception_ptr exception_ptr); | ||||
|  | ||||
| @ -540,8 +566,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   void enableCollectivesTiming() override; | ||||
|  | ||||
|   // Provide an API for users to define their own ways to store NCCL debug info. | ||||
|   void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer); | ||||
|   // Helper function for iteratively aborting communicators in the provided map | ||||
|   void abortCommsFromMap( | ||||
|       std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>& | ||||
|           ncclCommsMap, | ||||
|       c10::optional<std::string> abortReason); | ||||
|  | ||||
|   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) | ||||
|   // instead of relying on ProcessGroupNCCL destructor. | ||||
| @ -694,6 +723,19 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // Desync debug helper | ||||
|   void logWorkEnd(WorkNCCL& work); | ||||
|  | ||||
|   // Generates a prefix that is unique to this process group and rank, for | ||||
|   // disambiguating logs | ||||
|   std::string createLogPrefix() const; | ||||
|  | ||||
|   // Returns the unique prefix created in createLogPrefix | ||||
|   const std::string& logPrefix() const; | ||||
|  | ||||
|   // Returns the global rank of the device. This function assumes that users | ||||
|   // always create a default global process group (PG) which includes all | ||||
|   // devices. It is called in the constructor of ProcessGroupNCCL, so it always | ||||
|   // returns the rank_ of the very first PG created, aka, the default global PG. | ||||
|  | ||||
|  protected: | ||||
|   // Function that runs as part of a separate thread aside from watchdog | ||||
|   // thread because we need to check the heartbeat from watchdog thread | ||||
| @ -712,6 +754,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // for dump completion. | ||||
|   std::future<bool> launchAsyncDebugDump(); | ||||
|  | ||||
|   // Helper to wait up to the specified timeout and then abandon the dump. | ||||
|   // Logs on timeout, and asserts the future's status is as expected. | ||||
|   void waitForDumpOrTimeout( | ||||
|       std::future<bool>& fut, | ||||
|       const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime, | ||||
|       size_t timeout_sec = 30); | ||||
|  | ||||
|   // When watchdog timeout, this function will be called and return debug info | ||||
|   // for users. For now we only get information from retrieveDesyncReport. | ||||
|   // We are working on enabling more useful debug information for watchdog | ||||
| @ -720,9 +769,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   static const int64_t kWatchdogThreadSleepMillis; | ||||
|  | ||||
|   // The store is used to broadcast the NCCL unique ID of rank 0. | ||||
|   // The store is used to broadcast the NCCL unique ID of rank 0. This store | ||||
|   // comes with a prefix and is different across ProcessGroupNCCL instances | ||||
|   // (aka, different ProcessGroups). | ||||
|   c10::intrusive_ptr<Store> store_; | ||||
|  | ||||
|   // Reference to the store without prefix so that keys are the same across all | ||||
|   // ProcessGroupNCCL instances, and (key, value) pairs written to the store are | ||||
|   // global. | ||||
|   c10::intrusive_ptr<Store> globalStore_; | ||||
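
store_ is normally a PrefixStore (one wrapping layer per process group), while globalStore_ refers to the shared, un-prefixed base store so that keys such as TIMEOUT_DUMP are visible to every rank regardless of which PG wrote them. A rough sketch of the unwrapping this implies is shown below; unwrapPrefixStores is a hypothetical helper, comparable in spirit to the _underlying_non_prefix_store accessor added to the Python bindings in this change.

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

// Peel off PrefixStore layers until the shared base store is reached.
c10::intrusive_ptr<c10d::Store> unwrapPrefixStores(
    c10::intrusive_ptr<c10d::Store> store) {
  while (auto* prefixStore = dynamic_cast<c10d::PrefixStore*>(store.get())) {
    store = prefixStore->getUnderlyingStore(); // drop one prefix layer
  }
  return store;
}
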
|  | ||||
|   bool storeError_{false}; | ||||
|  | ||||
|   const c10::intrusive_ptr<Options> options_; | ||||
| @ -781,11 +837,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   std::mutex mutex_; | ||||
|  | ||||
|   // Heartbeat of watchdog thread. | ||||
|   uint64_t heartbeat_; | ||||
|   std::atomic_uint64_t heartbeat_; | ||||
|  | ||||
|   // The time interval used for deciding whether there is no watchdog heartbeat. | ||||
|   int heartbeatTimeoutInSec_; | ||||
|  | ||||
|   // Extra time of sleep when waiting for timeout dump to finish. | ||||
|   int waitTimeoutDumpInMilSec_; | ||||
|  | ||||
|   // Interval of check coordinated signals in ProcessGroupNCCL from other ranks | ||||
|   // e.g., trigger the dump of the debugging info for timeout when notified. | ||||
|   int coordCheckIntervalMilSec_; | ||||
|  | ||||
|   // Size of ring buffer where we store NCCL Traces for debugging. | ||||
|   int ncclTraceBufferSize_; | ||||
|  | ||||
| @ -823,9 +886,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   bool writeDebugInfo_ = false; | ||||
|  | ||||
|   // Mutex to Guard the check of writeDebugInfo_ | ||||
|   std::mutex writeDebugInfoMutex_; | ||||
|  | ||||
|   // Condition Variable for watchdog thread sleep | ||||
|   std::condition_variable workMetaListCV_; | ||||
|  | ||||
| @ -929,14 +989,25 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   std::exception_ptr watchDogException_ = nullptr; | ||||
|  | ||||
|   // The callback function to store NCCL debug info. | ||||
|   std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr; | ||||
|  | ||||
|   size_t uid_; | ||||
|  | ||||
|   std::string logPrefix_; | ||||
| }; | ||||
|  | ||||
| TORCH_API std::string dump_nccl_trace(); | ||||
|  | ||||
| // Gets a mutable reference to a global optional function. Heartbeat Monitor | ||||
| // will query this function and, if available, call it to dump traces. Inside | ||||
| // fbcode, we store a function here that uses an internal tool for process | ||||
| // tracing. | ||||
| TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper(); | ||||
|  | ||||
| // Similar to get_cpp_trace_dumper, this stores a function defined in | ||||
| // the torch-python layer that lets us check whether the GIL can be acquired, | ||||
| // helpful for instrumenting in cases where a hang was observed. | ||||
| typedef bool (*gil_checker_t)(); | ||||
|  | ||||
| TORCH_API gil_checker_t& get_gil_checker(); | ||||
| } // namespace c10d | ||||
|  | ||||
| #endif // USE_C10D_NCCL | ||||
|  | ||||
| @ -13,7 +13,6 @@ | ||||
| #include <string> | ||||
| #include <system_error> | ||||
| #include <vector> | ||||
|  | ||||
| namespace c10d { | ||||
|  | ||||
| /* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ | ||||
| @ -269,10 +268,20 @@ inline std::string retrieveDesyncReport( | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| DebugInfoWriter::DebugInfoWriter(int rank) { | ||||
|   std::string fileName = getCvarString( | ||||
|       {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|   filename_ = c10::str(fileName, rank); | ||||
| /* Helper used by work::getDuration() and nccl flight recorder */ | ||||
| float getDurationFromFirstEvent( | ||||
|     const std::vector<at::cuda::CUDAEvent>& ncclStartEvents, | ||||
|     const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) { | ||||
|   TORCH_CHECK( | ||||
|       ncclStartEvents.size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup, but found multiple start events."); | ||||
|   TORCH_CHECK( | ||||
|       ncclEndEvents.size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup, but found multiple end events."); | ||||
|   TORCH_CHECK( | ||||
|       ncclEndEvents[0].query(), | ||||
|       "getDuration can only be called after work has succeeded."); | ||||
|   return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]); | ||||
| } | ||||
|  | ||||
| DebugInfoWriter::~DebugInfoWriter() = default; | ||||
| @ -293,6 +302,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) { | ||||
|   LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; | ||||
| } | ||||
|  | ||||
| DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { | ||||
|   if (writer_ == nullptr) { | ||||
|     std::string fileNamePrefix = getCvarString( | ||||
|         {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|     // Using std::unique_ptr here to auto-delete the writer object | ||||
|     // when the pointer itself is destroyed. | ||||
|     std::unique_ptr<DebugInfoWriter> writerPtr( | ||||
|         new DebugInfoWriter(fileNamePrefix, rank)); | ||||
|     DebugInfoWriter::registerWriter(std::move(writerPtr)); | ||||
|   } | ||||
|   return *writer_; | ||||
| } | ||||
|  | ||||
| void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) { | ||||
|   TORCH_CHECK_WITH( | ||||
|       DistBackendError, | ||||
|       hasWriterRegistered_.load() == false, | ||||
|       "debugInfoWriter already registered"); | ||||
|   hasWriterRegistered_.store(true); | ||||
|   writer_ = std::move(writer); | ||||
| } | ||||
|  | ||||
| std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr; | ||||
| std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false); | ||||
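
Since only one writer can be registered per process, a custom destination for the flight-recorder dump is installed by subclassing DebugInfoWriter and calling registerWriter() before the first dump fires. A minimal sketch under those assumptions follows; the class and function names below are illustrative and not part of this patch, and the header that declares DebugInfoWriter is assumed to be on the include path (its location is not shown in this diff).

#include <iostream>
#include <memory>
#include <string>

// Routes the pickled flight-recorder blob to stderr instead of local disk.
class StderrDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  StderrDebugInfoWriter(std::string namePrefix, int rank)
      : DebugInfoWriter(std::move(namePrefix), rank) {}
  void write(const std::string& ncclTrace) override {
    std::cerr << "NCCL flight recorder dump (" << ncclTrace.size()
              << " bytes)" << std::endl;
  }
};

// Register before any dump can be triggered; registering a second writer throws.
void installStderrWriter(int globalRank) {
  c10d::DebugInfoWriter::registerWriter(
      std::make_unique<StderrDebugInfoWriter>("/tmp/nccl_trace_rank_", globalRank));
}
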
|  | ||||
| inline std::string pickle_str(const c10::IValue& v) { | ||||
|   std::vector<char> result; | ||||
|   { | ||||
| @ -337,10 +371,10 @@ struct NCCLTraceBuffer { | ||||
|                 // update state information | ||||
|     size_t pg_id_; | ||||
|     size_t seq_id_; // as tracked by the process group | ||||
|     const char* profiling_name_; | ||||
|     std::string profiling_name_; | ||||
|  | ||||
|     std::shared_ptr<torch::CapturedTraceback> traceback_; | ||||
|     // we borrow pointser to start_ and end_ so we can query the state | ||||
|     // we borrow pointers to start_ and end_ so we can query the state | ||||
|     // on reporting. However, once the event is completed, the call | ||||
|     // to `complete` will clear these. | ||||
|     EventList *start_, *end_; | ||||
| @ -348,8 +382,18 @@ struct NCCLTraceBuffer { | ||||
|     // timestamp when the entry was created, likely close to the time the work | ||||
|     // was 'enqueued'- not necessarily started | ||||
|     c10::time_t time_created_; | ||||
|     c10::optional<float> duration_; | ||||
|  | ||||
|     const char* state_ = "scheduled"; | ||||
|     // timestamp when our CPU threads discovered that the kernel started. | ||||
|     // will always be _after_ it actually started, and can be very late | ||||
|     // if the watchdog thread got stuck on CUDA APIs. | ||||
|     c10::optional<c10::time_t> time_discovered_started_; | ||||
|  | ||||
|     // timestamp when our CPU threads discovered that the kernel completed. | ||||
|     // will always be _after_ it actually completed, and can be the same time | ||||
|     // as the discovery of the start if the watchdog thread is stuck on CUDA | ||||
|     // APIs. | ||||
|     c10::optional<c10::time_t> time_discovered_completed_; | ||||
|  | ||||
|     // size information for input/output tensors | ||||
|     c10::SmallVector<int, 4> input_dims_; | ||||
| @ -370,7 +414,7 @@ struct NCCLTraceBuffer { | ||||
|   c10::optional<size_t> record( | ||||
|       size_t pg_id, | ||||
|       size_t seq_id, | ||||
|       const char* profiling_name, | ||||
|       std::string profiling_name, | ||||
|       const std::vector<at::Tensor>& inputs, | ||||
|       const std::vector<at::Tensor>& outputs, | ||||
|       EventList* start, | ||||
| @ -386,7 +430,7 @@ struct NCCLTraceBuffer { | ||||
|         id_, | ||||
|         pg_id, | ||||
|         seq_id, | ||||
|         profiling_name, | ||||
|         std::move(profiling_name), | ||||
|         std::move(traceback), | ||||
|         std::move(start), | ||||
|         std::move(end), | ||||
| @ -424,8 +468,8 @@ struct NCCLTraceBuffer { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       if (started) { | ||||
|         r.state_ = "started"; | ||||
|       if (started && !r.time_discovered_started_) { | ||||
|         r.time_discovered_started_ = c10::getTime(); | ||||
|       } | ||||
|     } | ||||
|     if (r.end_ != nullptr) { | ||||
| @ -436,8 +480,8 @@ struct NCCLTraceBuffer { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       if (completed) { | ||||
|         r.state_ = "completed"; | ||||
|       if (completed && !r.time_discovered_completed_) { | ||||
|         r.time_discovered_completed_ = c10::getTime(); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @ -456,35 +500,96 @@ struct NCCLTraceBuffer { | ||||
|     return result; | ||||
|   } | ||||
|  | ||||
|   void retire_id(c10::optional<size_t> id) { | ||||
|   /* | ||||
|   Mark an Event as completed and free its events. | ||||
|  | ||||
|   This is called by the watchdog thread, and is asynchronous from the | ||||
|   perspective of the main thread. | ||||
|  | ||||
|   compute_duration defaults to true since retire_id is only called in the | ||||
|   watchdog thread, which is currently a place where we already call CUDA APIs | ||||
|   that may hang, but care should be taken to avoid computing duration in any | ||||
|   function that must never hang. (Timing must also be enabled for | ||||
|   compute_duration - see TORCH_NCCL_ENABLE_TIMING.) | ||||
|   */ | ||||
|   void retire_id(c10::optional<size_t> id, bool compute_duration = true) { | ||||
|     if (!enabled_ || !id) { | ||||
|       return; | ||||
|     } | ||||
|     std::lock_guard<std::mutex> guard(mutex_); | ||||
|     auto& entry = entries_.at(*id % max_entries_); | ||||
|     if (entry.id_ == *id) { | ||||
|       update_state(entry); | ||||
|       entry.retired_ = true; | ||||
|       entry.start_ = entry.end_ = nullptr; | ||||
|  | ||||
|     bool can_compute_duration = false; | ||||
|     EventList* startEvents = nullptr; | ||||
|     EventList* endEvents = nullptr; | ||||
|     c10::optional<float> duration = c10::nullopt; | ||||
|  | ||||
|     std::unique_lock<std::mutex> guard(mutex_); | ||||
|  | ||||
|     Entry* entry = &entries_.at(*id % max_entries_); | ||||
|     if (entry->id_ == *id) { | ||||
|       update_state(*entry); | ||||
|  | ||||
|       if (compute_duration) { | ||||
|         can_compute_duration = entry->time_discovered_completed_.has_value() && | ||||
|             entry->start_ && entry->end_; | ||||
|         startEvents = entry->start_; | ||||
|         endEvents = entry->end_; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (can_compute_duration) { | ||||
|       // Compute duration without holding the lock, because | ||||
|       // cudaEventElapsedTime() can hang, and we need to acquire the lock before we | ||||
|       // can dump(), which we never want to block. | ||||
|       guard.unlock(); | ||||
|       duration = getDurationFromFirstEvent(*startEvents, *endEvents); | ||||
|       guard.lock(); | ||||
|  | ||||
|       // Refresh the entry pointer, see if the entry has been overwritten | ||||
|       entry = &entries_.at(*id % max_entries_); | ||||
|       if (entry->id_ != *id) { | ||||
|         LOG(INFO) | ||||
|             << "retire_id abandoned for id " << *id | ||||
|             << ", event was overwritten while waiting to compute duration."; | ||||
|         return; | ||||
|       } | ||||
|       if (duration.has_value()) { | ||||
|         entry->duration_ = duration.value(); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     entry->retired_ = true; | ||||
|     entry->start_ = entry->end_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   std::string dump() { | ||||
|   std::string dump( | ||||
|       const c10::optional<std::unordered_map< | ||||
|           std::string, | ||||
|           std::unordered_map<std::string, std::string>>>& ncclDumpMap) { | ||||
|     auto result = dump_entries(); | ||||
|     auto entries = new_list(); | ||||
|     c10::IValue pg_id_s = "pg_id"; | ||||
|     c10::IValue seq_id_s = "seq_id"; | ||||
|     c10::IValue profiling_name_s = "profiling_name"; | ||||
|     c10::IValue input_sizes_s = "input_sizes"; | ||||
|     c10::IValue output_sizes_s = "output_sizes"; | ||||
|     c10::IValue time_created_s = "time_created_us"; | ||||
|     c10::IValue entries_key = "entries"; | ||||
|     c10::IValue nccl_comm_key = "nccl_comm_state"; | ||||
|     c10::IValue version_key = "version"; | ||||
|     // Update whenever changing contents or formatting of the dump | ||||
|     // (minor when adding fields, major when changing existing fields) | ||||
|     c10::IValue version_val = "1.1"; | ||||
|  | ||||
|     c10::IValue frames_s = "frames"; | ||||
|     c10::IValue state_s = "state"; | ||||
|     c10::IValue line_s = "line"; | ||||
|     c10::IValue name_s = "name"; | ||||
|     c10::IValue filename_s = "filename"; | ||||
|     c10::IValue retired_s = "retired"; | ||||
|     c10::IValue pg_id_key = "pg_id"; | ||||
|     c10::IValue seq_id_key = "seq_id"; | ||||
|     c10::IValue profiling_name_key = "profiling_name"; | ||||
|     c10::IValue input_sizes_key = "input_sizes"; | ||||
|     c10::IValue output_sizes_key = "output_sizes"; | ||||
|     c10::IValue time_created_key = "time_created_ns"; | ||||
|     c10::IValue duration_key = "duration_ms"; | ||||
|  | ||||
|     c10::IValue frames_key = "frames"; | ||||
|     c10::IValue state_key = "state"; | ||||
|     c10::IValue line_key = "line"; | ||||
|     c10::IValue name_key = "name"; | ||||
|     c10::IValue filename_key = "filename"; | ||||
|     c10::IValue retired_key = "retired"; | ||||
|     c10::IValue time_discovered_started_key = "time_discovered_started_ns"; | ||||
|     c10::IValue time_discovered_completed_key = "time_discovered_completed_ns"; | ||||
|  | ||||
|     std::vector<torch::CapturedTraceback*> tracebacks; | ||||
|     for (auto& e : result) { | ||||
| @ -494,9 +599,9 @@ struct NCCLTraceBuffer { | ||||
|     std::vector<c10::IValue> all_frames; | ||||
|     for (const auto& f : stracebacks.all_frames) { | ||||
|       auto d = new_dict(); | ||||
|       d.insert(name_s, f.funcname); | ||||
|       d.insert(filename_s, f.filename); | ||||
|       d.insert(line_s, int64_t(f.lineno)); | ||||
|       d.insert(name_key, f.funcname); | ||||
|       d.insert(filename_key, f.filename); | ||||
|       d.insert(line_key, int64_t(f.lineno)); | ||||
|       all_frames.emplace_back(std::move(d)); | ||||
|     } | ||||
|  | ||||
| @ -504,10 +609,13 @@ struct NCCLTraceBuffer { | ||||
|       auto& e = result.at(i); | ||||
|       auto& tb = stracebacks.tracebacks.at(i); | ||||
|       auto dict = new_dict(); | ||||
|       dict.insert(pg_id_s, int64_t(e.pg_id_)); | ||||
|       dict.insert(seq_id_s, int64_t(e.seq_id_)); | ||||
|       dict.insert(profiling_name_s, e.profiling_name_); | ||||
|       dict.insert(time_created_s, int64_t(e.time_created_ / 1000)); | ||||
|       dict.insert(pg_id_key, int64_t(e.pg_id_)); | ||||
|       dict.insert(seq_id_key, int64_t(e.seq_id_)); | ||||
|       dict.insert(profiling_name_key, e.profiling_name_); | ||||
|       dict.insert(time_created_key, int64_t(e.time_created_)); | ||||
|       if (e.duration_) { | ||||
|         dict.insert(duration_key, *e.duration_); | ||||
|       } | ||||
|  | ||||
|       auto it = e.sizes_.begin(); | ||||
|       auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) { | ||||
| @ -523,19 +631,56 @@ struct NCCLTraceBuffer { | ||||
|         return sizes; | ||||
|       }; | ||||
|  | ||||
|       dict.insert(input_sizes_s, read_sizes(e.input_dims_)); | ||||
|       dict.insert(output_sizes_s, read_sizes(e.output_dims_)); | ||||
|       dict.insert(state_s, e.state_); | ||||
|       dict.insert(retired_s, e.retired_); | ||||
|       dict.insert(input_sizes_key, read_sizes(e.input_dims_)); | ||||
|       dict.insert(output_sizes_key, read_sizes(e.output_dims_)); | ||||
|       if (e.time_discovered_completed_.has_value()) { | ||||
|         dict.insert(state_key, "completed"); | ||||
|       } else if (e.time_discovered_started_.has_value()) { | ||||
|         dict.insert(state_key, "started"); | ||||
|       } else { | ||||
|         dict.insert(state_key, "scheduled"); | ||||
|       } | ||||
|  | ||||
|       dict.insert( | ||||
|           time_discovered_started_key, | ||||
|           e.time_discovered_started_.has_value() | ||||
|               ? int64_t(*e.time_discovered_started_) | ||||
|               : c10::IValue()); | ||||
|       dict.insert( | ||||
|           time_discovered_completed_key, | ||||
|           e.time_discovered_completed_.has_value() | ||||
|               ? int64_t(*e.time_discovered_completed_) | ||||
|               : c10::IValue()); | ||||
|       dict.insert(retired_key, e.retired_); | ||||
|  | ||||
|       auto frames = new_list(); | ||||
|       for (int64_t frame : tb) { | ||||
|         frames.push_back(all_frames.at(frame)); | ||||
|       } | ||||
|       dict.insert(frames_s, frames); | ||||
|       dict.insert(frames_key, frames); | ||||
|       entries.push_back(dict); | ||||
|     } | ||||
|     return pickle_str(entries); | ||||
|  | ||||
|     // convert ncclDumpMap into a dictionary | ||||
|     auto per_comm_dict = new_dict(); | ||||
|     if (ncclDumpMap.has_value()) { | ||||
|       for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { | ||||
|         auto inner_dict = new_dict(); | ||||
|         for (const auto& [key, value] : ncclDump) { | ||||
|           inner_dict.insert(key, value); | ||||
|         } | ||||
|         per_comm_dict.insert(ncclId, inner_dict); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     auto dict = new_dict(); | ||||
|     dict.insert(entries_key, entries); | ||||
|     dict.insert(version_key, version_val); | ||||
|     if (per_comm_dict.size() > 0) { | ||||
|       dict.insert(nccl_comm_key, per_comm_dict); | ||||
|     } | ||||
|  | ||||
|     return pickle_str(dict); | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -50,6 +50,35 @@ | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| bool acquire_gil() { | ||||
|   // Basically, if this function can acquire the GIL, it will return quickly; | ||||
|   // if not, it will hang forever. The idea is to call this from a thread | ||||
|   // wrapped in a future, and then check the future after a timeout, to | ||||
|   // determine whether we're facing GIL contention. | ||||
|   if (Py_IsInitialized()) { | ||||
|     pybind11::gil_scoped_acquire gil; | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   // If we end up here, it's probably still a "pass" from the perspective of | ||||
|   // checking whether Python is stuck, but currently we don't check the return | ||||
|   // value of this function anyway; we just check whether it returned quickly vs | ||||
|   // timing out. Taking a long time is the main sign of trouble. A fast return | ||||
|   // with either true or false is OK from the perspective of debugging Python | ||||
|   // hangs. | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| bool registerGilChecker() { | ||||
|   c10d::get_gil_checker() = &acquire_gil; | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| static bool registered = registerGilChecker(); | ||||
| #endif // USE_C10D_NCCL | ||||
|  | ||||
| // Wrapper to ensure GIL is released before destructing ProcessGroupGloo | ||||
| // TODO: move this somewhere more generally useful | ||||
| template <typename T> | ||||
| @ -1033,6 +1062,29 @@ Example:: | ||||
|     >>> store.add("first_key", 6) | ||||
|     >>> # Should return 7 | ||||
|     >>> store.get("first_key") | ||||
| )") | ||||
|           .def( | ||||
|               "check", | ||||
|               &::c10d::Store::check, | ||||
|               py::call_guard<py::gil_scoped_release>(), | ||||
|               R"( | ||||
| Checks whether each key in the given list of ``keys`` has a value stored in | ||||
| the store. This call returns immediately in normal cases but still suffers | ||||
| from some edge deadlock cases, e.g., calling check after the TCPStore has been destroyed. | ||||
| Call :meth:`~torch.distributed.store.check` with the list of keys whose | ||||
| presence in the store one wants to verify. | ||||
|  | ||||
| Arguments: | ||||
|     keys (list[str]): The keys to query for presence in the store. | ||||
|  | ||||
| Example:: | ||||
|     >>> import torch.distributed as dist | ||||
|     >>> from datetime import timedelta | ||||
|     >>> # Using TCPStore as an example, other store types can also be used | ||||
|     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) | ||||
|     >>> store.add("first_key", 1) | ||||
|     >>> # Should return True | ||||
|     >>> store.check(["first_key"]) | ||||
| )") | ||||
|           .def( | ||||
|               "delete_key", | ||||
| @ -1404,7 +1456,11 @@ Arguments: | ||||
|       .def_property_readonly( | ||||
|           "underlying_store", | ||||
|           &::c10d::PrefixStore::getUnderlyingStore, | ||||
|           R"(Gets the underlying store object that PrefixStore wraps around.)"); | ||||
|           R"(Gets the underlying store object that PrefixStore wraps around.)") | ||||
|       .def_property_readonly( | ||||
|           "_underlying_non_prefix_store", | ||||
|           &::c10d::PrefixStore::getUnderlyingNonPrefixStore, | ||||
|           R"(Recursively unwraps all layers of PrefixStore to get the underlying non-prefix store.)"); | ||||
|  | ||||
|   auto processGroup = | ||||
|       py::class_< | ||||
|  | ||||
| @ -28,7 +28,7 @@ def tail_logfile( | ||||
|             return | ||||
|         time.sleep(interval_sec) | ||||
|  | ||||
|     with open(file) as fp: | ||||
|     with open(file, errors="replace") as fp: | ||||
|         while True: | ||||
|             line = fp.readline() | ||||
|  | ||||
|  | ||||