mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Compare commits
	
		
			55 Commits
		
	
	
		
			codex/fix-
			...
			flight_5
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b33a283e9a | |||
| 7a551d81e5 | |||
| 1515a90475 | |||
| 4882ec2a91 | |||
| 972b8060bd | |||
| 3e7683ae18 | |||
| f2e9ec2dc5 | |||
| dde4324d8e | |||
| 94c079104d | |||
| a6afee6d94 | |||
| d092857531 | |||
| 6aad5e444a | |||
| c54ce9313b | |||
| 1fe59f4ef7 | |||
| e693fb2bb1 | |||
| 4fe510baf6 | |||
| 7c507b78c4 | |||
| 0019901601 | |||
| 18be18535b | |||
| 2729367313 | |||
| 33537aae24 | |||
| dcdb1337dd | |||
| 9cf0f2bd59 | |||
| 1d2e877c05 | |||
| f27b979b0c | |||
| f30d6047ad | |||
| 75311510ef | |||
| dbd6094d05 | |||
| 397b9d47e9 | |||
| 36a01a8ab9 | |||
| ee336cf58a | |||
| a9e2e745d7 | |||
| ab4df89eea | |||
| 9d02ebe876 | |||
| b61e01cce9 | |||
| f7ce61ba53 | |||
| e7bae15ab1 | |||
| e71b422908 | |||
| 389940ce60 | |||
| b2237a7c85 | |||
| ef5dfe3f3e | |||
| e303dc3c08 | |||
| 265efad2de | |||
| 60f0455905 | |||
| 4898313791 | |||
| f4da9adf6b | |||
| 8f7f35273e | |||
| 44ec9612ed | |||
| 4d3bea2b29 | |||
| 0bcdddc3c1 | |||
| 28b6220312 | |||
| 210b7b65e2 | |||
| 4da10b5cd3 | |||
| f09763814f | |||
| 80923ed5a6 | 
| @ -1732,7 +1732,7 @@ if(BUILD_TEST) | ||||
|   foreach(test_src ${Caffe2_CPU_TEST_SRCS}) | ||||
|     get_filename_component(test_name ${test_src} NAME_WE) | ||||
|     add_executable(${test_name} "${test_src}") | ||||
|     target_link_libraries(${test_name} torch_library gtest_main) | ||||
|     target_link_libraries(${test_name} torch_library gtest_main stdc++) | ||||
|     target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>) | ||||
|     target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>) | ||||
|     target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) | ||||
|  | ||||
| @ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) { | ||||
| // Extend the nested class outside the parent class | ||||
| class TestDebugInfoWriter : public c10d::DebugInfoWriter { | ||||
|  public: | ||||
|   TestDebugInfoWriter() : DebugInfoWriter(0) {} | ||||
|   TestDebugInfoWriter(std::string namePrefix) | ||||
|       : DebugInfoWriter(namePrefix, 0) {} | ||||
|  | ||||
|   void write(const std::string& ncclTrace) override { | ||||
|     traces_.assign(ncclTrace.begin(), ncclTrace.end()); | ||||
| @ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { | ||||
|   // The storer here is very similar to the fallback storer. | ||||
|   // The only difference is that we are storing traces also in memory for | ||||
|   // validation. | ||||
|   std::string fileNamePrefix = c10d::getCvarString( | ||||
|       {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|   std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr = | ||||
|       std::make_unique<TestDebugInfoWriter>(); | ||||
|       std::make_unique<TestDebugInfoWriter>(fileNamePrefix); | ||||
|   std::vector<uint8_t>& traces = wrterForTestPtr->getTraces(); | ||||
|   pg.registerDebugInfoWriter(std::move(wrterForTestPtr)); | ||||
|   c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr)); | ||||
|  | ||||
|   // Normal collective case. | ||||
|   auto work = pg.allreduce(tensors_); | ||||
| @ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { | ||||
| class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest { | ||||
|  protected: | ||||
|   void SetUp() override { | ||||
|     // TODO (kwen2501) | ||||
|     GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; " | ||||
|                  << "will rewrite them after refactoring Work queues."; | ||||
|     ProcessGroupNCCLErrorsTest::SetUp(); | ||||
|     std::string timeInterval = std::to_string(heartBeatIntervalInSec); | ||||
|     ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); | ||||
|  | ||||
| @ -4,9 +4,10 @@ import torch | ||||
| import torch.distributed as dist | ||||
| import torch.distributed._functional_collectives as funcol | ||||
|  | ||||
| from torch.distributed._tensor import DTensor | ||||
| from torch.distributed._tensor.placement_types import Shard | ||||
| from torch.distributed.checkpoint._state_dict_utils import ( | ||||
| from torch.distributed._state_dict_utils import ( | ||||
|     _check_state_dict_similarity, | ||||
|     _copy_state_dict, | ||||
|     _create_cpu_state_dict, | ||||
|     _gather_state_dict, | ||||
|     _offload_state_dict_to_cpu, | ||||
| ) | ||||
| @ -115,6 +116,58 @@ class TestStateDictUtils(DTensorTestBase): | ||||
|         } | ||||
|         self.assertEqual(state_dict, _gather_state_dict(dist_state_dict)) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_create_cpu_state_dict(self): | ||||
|         device = torch.device("cuda") | ||||
|         buffer = io.BytesIO() | ||||
|         torch.save(torch.ones(10), buffer) | ||||
|         buffer.seek(0) | ||||
|         state_dict = { | ||||
|             "tensor1": torch.arange(10, device=device), | ||||
|             "tensor2": torch.ones(10, device=device), | ||||
|             "non_tensor_bytes_io": copy.deepcopy(buffer), | ||||
|             "non_tensor_bytes": buffer.read(), | ||||
|             "step": torch.tensor(7, dtype=torch.float), | ||||
|             "lr": 1.5, | ||||
|             "nested": {"list": [1, 2, 3, 4]}, | ||||
|         } | ||||
|  | ||||
|         def _verify(cpu_state_dict): | ||||
|             # Verify the correctness of _check_state_dict_similarity() | ||||
|             self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict)) | ||||
|             tensor1 = cpu_state_dict["tensor1"] | ||||
|             cpu_state_dict["tensor1"] = torch.arange(11) | ||||
|             self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict)) | ||||
|             cpu_state_dict["tensor1"] = tensor1 | ||||
|  | ||||
|             _copy_state_dict(state_dict, cpu_state_dict) | ||||
|  | ||||
|             # Verify if _copy_state_dict works | ||||
|             for v in cpu_state_dict.values(): | ||||
|                 if isinstance(v, torch.Tensor): | ||||
|                     self.assertFalse(v.is_cuda) | ||||
|             self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) | ||||
|             self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) | ||||
|             buffer.seek(0) | ||||
|             cpu_state_dict["non_tensor_bytes_io"].seek(0) | ||||
|             self.assertEqual( | ||||
|                 cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read() | ||||
|             ) | ||||
|             buffer.seek(0) | ||||
|             self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read()) | ||||
|             self.assertEqual(cpu_state_dict["lr"], 1.5) | ||||
|             self.assertEqual(cpu_state_dict["step"], 7) | ||||
|             self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]}) | ||||
|  | ||||
|         cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True) | ||||
|         _verify(cpu_state_dict) | ||||
|         cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True) | ||||
|         _verify(cpu_state_dict) | ||||
|         cpu_state_dict = _create_cpu_state_dict( | ||||
|             state_dict, share_memory=True, pin_memory=True | ||||
|         ) | ||||
|         _verify(cpu_state_dict) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -11,6 +11,7 @@ import tempfile | ||||
| import threading | ||||
| import pickle | ||||
| import time | ||||
| import json | ||||
| import warnings | ||||
| from contextlib import contextmanager | ||||
| from datetime import datetime, timedelta | ||||
| @ -1334,6 +1335,19 @@ class ProcessGroupNCCLTest(MultiProcessTestCase): | ||||
|         self.assertEqual(tensor, original_tensor) | ||||
|  | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") | ||||
|     def test_set_process_group_desc(self): | ||||
|         store = c10d.FileStore(self.file_name, self.world_size) | ||||
|         device = torch.device(f'cuda:{self.rank}') | ||||
|         pg_default = self._create_process_group_nccl(store, self.opts(), device_id=device) | ||||
|         self.assertEqual(pg_default.group_desc, "default_pg") | ||||
|         pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") | ||||
|         self.assertEqual(pg_1.group_desc, "test_purpose") | ||||
|         pg_2 = c10d.new_group([0, 1]) | ||||
|         self.assertEqual(pg_2.group_desc, "undefined") | ||||
|  | ||||
|  | ||||
| class DistributedDataParallelTest( | ||||
|     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase | ||||
| ): | ||||
| @ -3542,11 +3556,12 @@ class SparseCollective(MultiProcessTestCase): | ||||
| class NCCLTraceTestBase(MultiProcessTestCase): | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
|         os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0' | ||||
|         os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'  # see 'timing_enabled' parametrized tests | ||||
|         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10' | ||||
|         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1' | ||||
|         self.tempdir = tempfile.TemporaryDirectory() | ||||
|         os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename() | ||||
|         os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename() | ||||
|         self._spawn_processes() | ||||
|  | ||||
|     @classmethod | ||||
| @ -3617,28 +3632,50 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     def test_short(self): | ||||
|     @parametrize("timing_enabled", [True, False]) | ||||
|     def test_short(self, timing_enabled): | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             return | ||||
|         pg = self._create_process_group_nccl() | ||||
|         if timing_enabled: | ||||
|             pg._enable_collectives_timing() | ||||
|         device = self.local_device | ||||
|         a = torch.full((3, 4), float(self.rank), device=device) | ||||
|         for i in range(2): | ||||
|             f = pg.allreduce(a) | ||||
|         f.wait() | ||||
|         torch.cuda.synchronize(device=device) | ||||
|  | ||||
|         # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api | ||||
|         time.sleep(1) | ||||
|  | ||||
|         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|         ver = t['version'] | ||||
|         self.assertEqual(ver, "1.1") | ||||
|         t = t['entries'] | ||||
|         self.assertEqual(len(t), 2) | ||||
|         last = t[-1] | ||||
|         self.assertEqual(last['process_group'], ('0', 'default_pg')) | ||||
|         self.assertEqual(last['state'], 'completed') | ||||
|         s = last['time_discovered_started_ns'] | ||||
|         f = last['time_discovered_completed_ns'] | ||||
|         self.assertIsNotNone(f) | ||||
|         if timing_enabled: | ||||
|             self.assertIsNotNone(s) | ||||
|             self.assertTrue(s <= f) | ||||
|         self.assertIn('test_c10d_nccl.py', str(last['frames'])) | ||||
|         self.assertEqual(last['input_sizes'], ((3, 4),)) | ||||
|         self.assertEqual(last['output_sizes'], ((3, 4),)) | ||||
|         self.assertEqual(last['seq_id'], 2) | ||||
|         now = datetime.now() | ||||
|         event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000) | ||||
|         event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000) | ||||
|         before_test = now - timedelta(minutes=1) | ||||
|         self.assertTrue(before_test < event_created_time < now) | ||||
|         if timing_enabled: | ||||
|             # very loose bounds, measured 0.036 ms on devgpu | ||||
|             self.assertTrue(0 < last['duration_ms'] < 100) | ||||
|         else: | ||||
|             self.assertTrue("duration_ms" not in last) | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
| @ -3698,9 +3735,11 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|         f.wait() | ||||
|         torch.cuda.synchronize(device=device) | ||||
|         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|         t = t['entries'] | ||||
|         self.assertEqual(len(t), 10) | ||||
|         first = t[0] | ||||
|         last = t[-1] | ||||
|         self.assertEqual(last['profiling_name'], 'nccl:all_reduce') | ||||
|         self.assertEqual(last['state'], 'completed') | ||||
|         self.assertIn('test_c10d_nccl.py', str(last['frames'])) | ||||
|         self.assertEqual(last['input_sizes'], ((3, 4),)) | ||||
| @ -3732,6 +3771,8 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|                 pg.allreduce(a).wait() | ||||
|             e.synchronize() | ||||
|             t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|             t = t['entries'] | ||||
|             self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce') | ||||
|             if self.rank == 0: | ||||
|                 self.assertEqual(t[-1]['seq_id'], 1) | ||||
|                 self.assertEqual(t[-1]['state'], 'completed') | ||||
| @ -3773,12 +3814,15 @@ class NCCLTraceTest(NCCLTraceTestBase): | ||||
|                 # give the other thread some time to fill the cuda buffer | ||||
|                 time.sleep(5) | ||||
|                 t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce') | ||||
|                 if self.rank == 0: | ||||
|                     self.assertEqual(t[-1]['seq_id'], 1) | ||||
|                     self.assertEqual(t[-1]['state'], 'completed') | ||||
|                 else: | ||||
|                     self.assertEqual(t[-1]['seq_id'], 2) | ||||
|                     self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled)) | ||||
|                     self.assertIsNone(t[-1]['time_discovered_completed_ns']) | ||||
|                 # this will eventually cause the missing rank 0 | ||||
|                 # to continue which will unblock the non-zero ranks | ||||
|                 self.parent.send('next') | ||||
| @ -3832,6 +3876,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     @parametrize("timing_enabled", [True, False]) | ||||
|     def test_timeout_dumps(self, timing_enabled): | ||||
|         # dump on heartbeatmonitor thread | ||||
|         os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000' | ||||
|         # need rank0 to crash before looking for its output file | ||||
|         os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1' | ||||
|  | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             # wait for rank0 to crash before looking for its output file | ||||
| @ -3839,6 +3887,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             self.assertEqual(self._wait_process(0, timeout=90), -6) | ||||
|             with open(self._trace_name(rank=0), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 2) | ||||
|                 self.assertEqual(t[0]['seq_id'], 1) | ||||
|                 self.assertEqual(t[0]['state'], 'completed') | ||||
| @ -3868,7 +3917,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase): | ||||
| instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout) | ||||
| instantiate_parametrized_tests(NCCLTraceTest) | ||||
|  | ||||
| class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
| class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|     def _check_return_codes(self, elapsed_time): | ||||
|         # the base test infra assumes processes exit with matching return codes, | ||||
|         # but we want rank0 to abort and rank1 to exit cleanly in this test | ||||
| @ -3877,7 +3926,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") | ||||
|     def test_timeout_dumps_on_idle_ranks(self): | ||||
|     def test_timeout_dumps_on_stuck_ranks(self): | ||||
|  | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             # wait for both rank0 and 1 to crash before looking for both ranks' output | ||||
| @ -3888,20 +3937,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             self.assertTrue(os.path.exists(self._trace_name(rank=0))) | ||||
|             with open(self._trace_name(rank=0), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 2) | ||||
|             with open(self._trace_name(rank=1), 'rb') as f: | ||||
|                 t = pickle.load(f) | ||||
|                 t = t['entries'] | ||||
|                 self.assertEqual(len(t), 1) | ||||
|                 self.assertEqual(t[0]['seq_id'], 1) | ||||
|                 self.assertEqual(t[0]['state'], 'completed') | ||||
|             return | ||||
|  | ||||
|         # Set heartbeat timeout to a shorter one (default timeout is 2 min). | ||||
|         os.environ[ | ||||
|             "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC" | ||||
|         ] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}" | ||||
|         pg = self._create_process_group_nccl() | ||||
|  | ||||
|         device = self.local_device | ||||
|         with torch.cuda.device(device): | ||||
|             a = torch.full((3, 4), float(self.rank), device=device) | ||||
| @ -3910,12 +3956,68 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase): | ||||
|             if self.rank == 0: | ||||
|                 pg.allreduce(a).wait() | ||||
|  | ||||
|             # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly | ||||
|             # rank 0 will get stuck, timeout and then signal a timeout to all ranks. | ||||
|             torch.cuda.synchronize() | ||||
|  | ||||
|             # Force rank 1 to idle so that it also gets debug info dump triggered. | ||||
|             if self.rank == 1: | ||||
|                 time.sleep(6) | ||||
|                 # Force rank 1 to idle so that it will eventually timeout as well after | ||||
|                 # getting the global signal to dump the debugging info. | ||||
|                 time.sleep(600) | ||||
|  | ||||
| class NcclErrorDumpTest(NCCLTraceTestBase): | ||||
|     def _wait_process(self, rank, timeout): | ||||
|         try: | ||||
|             self.processes[rank].join(timeout) | ||||
|             return self.processes[rank].exitcode | ||||
|         except TimeoutError: | ||||
|             return None | ||||
|  | ||||
|     def _check_return_codes(self, elapsed_time): | ||||
|         # the base test infra assumes processes exit with matching return codes, | ||||
|         # but we want rank0 to abort with exception and rank1 to exit with exit 1 | ||||
|         self.assertEqual(self.processes[0].exitcode, -6) | ||||
|         self.assertEqual(self.processes[1].exitcode, 1) | ||||
|  | ||||
|     @requires_nccl() | ||||
|     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @skip_if_rocm | ||||
|     def test_nccl_errors_dump(self): | ||||
|         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" | ||||
|         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000' | ||||
|         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1' | ||||
|         # need rank0 to dump before abort | ||||
|         os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5' | ||||
|  | ||||
|         if self.rank == self.MAIN_PROCESS_RANK: | ||||
|             # wait for both rank0 and 1 to crash before looking for dump | ||||
|             self.assertEqual(self._wait_process(0, timeout=90), -6) | ||||
|             self.assertEqual(self._wait_process(1, timeout=90), 1) | ||||
|             # verify that the trace file exists for rank0 | ||||
|             self.assertTrue(os.path.exists(self._trace_name(rank=0))) | ||||
|             return | ||||
|  | ||||
|         store = c10d.FileStore(self.file_name, self.world_size) | ||||
|         process_group = c10d.ProcessGroupNCCL( | ||||
|             store, | ||||
|             self.rank, | ||||
|             self.world_size, | ||||
|             timeout=timedelta(seconds=10), | ||||
|         ) | ||||
|         process_group.allreduce(torch.rand(10).cuda(self.rank)) | ||||
|         if self.rank == 0: | ||||
|             work = process_group.allreduce(torch.rand(10).cuda(self.rank)) | ||||
|             # expect an error to be raised | ||||
|             with self.assertRaisesRegex(dist.DistBackendError, ""): | ||||
|                 # Block the current stream on the NCCL stream | ||||
|                 work.wait() | ||||
|                 # Run some GPU operations | ||||
|                 a = torch.rand(10).cuda(self.rank) | ||||
|         elif self.rank == 1: | ||||
|             # Clean up structures (ex: files for FileStore before going down) | ||||
|             del process_group | ||||
|             sys.exit(1) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     assert ( | ||||
|  | ||||
| @ -71,7 +71,7 @@ class StoreTestBase: | ||||
|     def _create_store(self, i): | ||||
|         raise RuntimeError("not implemented") | ||||
|  | ||||
|     def _test_set_get(self, fs): | ||||
|     def _test_set_get_check(self, fs): | ||||
|         fs.add("key", 1) | ||||
|         fs.add("key", 2) | ||||
|         fs.add("key", 3) | ||||
| @ -90,14 +90,16 @@ class StoreTestBase: | ||||
|         self.assertEqual(b"value1", fs.get("key1")) | ||||
|         self.assertEqual(b"value2", fs.get("key2")) | ||||
|         self.assertEqual(b"21", fs.get("key3")) | ||||
|         self.assertTrue(fs.check(["key3"])) | ||||
|         self.assertFalse(fs.check(["Randomkey3"])) | ||||
|  | ||||
|         fs.set("-key3", "7") | ||||
|         self.assertEqual(b"7", fs.get("-key3")) | ||||
|         fs.delete_key("-key3") | ||||
|         self.assertEqual(fs.num_keys(), self.num_keys_total) | ||||
|  | ||||
|     def test_set_get(self): | ||||
|         self._test_set_get(self._create_store()) | ||||
|     def test_set_get_check(self): | ||||
|         self._test_set_get_check(self._create_store()) | ||||
|  | ||||
|     def _test_compare_set(self, store): | ||||
|         missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0") | ||||
| @ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase): | ||||
|     def num_keys_total(self): | ||||
|         return 6 | ||||
|  | ||||
|     def test_underlying_non_prefix_store(self): | ||||
|         store = self._create_store() | ||||
|         wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store)) | ||||
|         self.assertEqual(self.tcpstore, store._underlying_non_prefix_store) | ||||
|         self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store) | ||||
|  | ||||
| class MyPythonStore(dist.Store): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
| @ -463,6 +463,7 @@ class ProcessGroup: | ||||
|         backend: Optional[ProcessGroup], | ||||
|     ) -> None: ... | ||||
|     def _set_group_name(self, name: str) -> None: ... | ||||
|     def _set_group_desc(self, desc: str) -> None: ... | ||||
|     def name(self) -> str: ... | ||||
|     def _has_hooks(self) -> bool: ... | ||||
|     def _wait_for_pending_works(self) -> None: ... | ||||
| @ -471,6 +472,10 @@ class ProcessGroup: | ||||
|     def bound_device_id(self) -> Optional[torch.device]: ... | ||||
|     @bound_device_id.setter | ||||
|     def bound_device_id(self, device: Optional[torch.device]) -> None: ... | ||||
|     @property | ||||
|     def group_name(self) -> str: ... | ||||
|     @property | ||||
|     def group_desc(self) -> str: ... | ||||
|  | ||||
| class ProcessGroupRoundRobin(ProcessGroup): ... | ||||
|  | ||||
|  | ||||
| @ -369,6 +369,14 @@ class TORCH_API Backend : public torch::CustomClassHolder { | ||||
|     return pg_name_; | ||||
|   } | ||||
|  | ||||
|   void setGroupDesc(const std::string& desc) { | ||||
|     pg_desc_ = desc; | ||||
|   } | ||||
|  | ||||
|   const std::string& getGroupDesc() const { | ||||
|     return pg_desc_; | ||||
|   } | ||||
|  | ||||
|   // See similar functions in ProcessGroup.hpp for context. | ||||
|   c10::optional<at::Device> getBoundDeviceId() const { | ||||
|     return bound_device_id_; | ||||
| @ -399,6 +407,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { | ||||
|   // remains the same across use of this process group. | ||||
|   DebugLevel dist_debug_level_; | ||||
|   std::string pg_name_; | ||||
|   std::string pg_desc_; | ||||
|  | ||||
|   std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_; | ||||
|  | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
|  | ||||
| #include <ATen/ATen.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/Optional.h> | ||||
| #include <nccl.h> | ||||
| @ -175,14 +176,29 @@ std::string getNcclErrorDetailStr( | ||||
|     c10::optional<std::string> processGroupFailureReason = c10::nullopt); | ||||
|  | ||||
| // Write NCCL debug info to local disk or any storage users define. | ||||
| // There are some constrains we set for the debug info writer: | ||||
| // 1. The writer should only be registered once. | ||||
| // 2. Once registered, users cannot change it including un-register. | ||||
| // 3. It is recommended to register the customized writer in the trainer setup, | ||||
| //    If users don't register before calling launchAsyncDebugDump, then users | ||||
| //    lose the chance to register (and the default writer will be | ||||
| //    auto-registered). | ||||
| class TORCH_API DebugInfoWriter { | ||||
|  public: | ||||
|   DebugInfoWriter(int rank); | ||||
|   virtual ~DebugInfoWriter(); | ||||
|   virtual void write(const std::string& ncclTrace); | ||||
|   static DebugInfoWriter& getWriter(int rank); | ||||
|   static void registerWriter(std::unique_ptr<DebugInfoWriter> writer); | ||||
|  | ||||
|  protected: | ||||
|   DebugInfoWriter(std::string namePrefix, int rank) { | ||||
|     filename_ = c10::str(namePrefix, rank); | ||||
|   } | ||||
|   std::string filename_; | ||||
|  | ||||
|  private: | ||||
|   static std::unique_ptr<DebugInfoWriter> writer_; | ||||
|   static std::atomic<bool> hasWriterRegistered_; | ||||
| }; | ||||
|  | ||||
| // RAII wrapper for NCCL communicator | ||||
| @ -267,6 +283,18 @@ class NCCLComm { | ||||
|   } | ||||
| #endif | ||||
|  | ||||
| #if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP) | ||||
|   std::unordered_map<std::string, std::string> ncclCommDump() { | ||||
|     std::unordered_map<std::string, std::string> dump; | ||||
|     if (isAborted()) { | ||||
|       LOG(INFO) << "Communicator was aborted before trying to dump its state."; | ||||
|       return dump; | ||||
|     } | ||||
|     C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt); | ||||
|     return dump; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   ncclUniqueId getNcclId() { | ||||
|     return ncclId_; | ||||
|   } | ||||
| @ -322,6 +350,9 @@ class NCCLComm { | ||||
|     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work | ||||
|     // timeout) | ||||
|     commFailureReason_ = commFailureReason; | ||||
|     LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: " | ||||
|               << (commFailureReason ? *commFailureReason | ||||
|                                     : "No abort reason provided."); | ||||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||||
|     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); | ||||
| #else | ||||
| @ -421,6 +452,8 @@ class NCCLComm { | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   friend class ProcessGroupNCCL; | ||||
|  | ||||
|  protected: | ||||
|   ncclComm_t ncclComm_; | ||||
|   // Unique nccl_id for this communicator. | ||||
|  | ||||
| @ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() { | ||||
|   return store_; | ||||
| } | ||||
|  | ||||
| c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() { | ||||
|   c10::intrusive_ptr<Store> store = store_; | ||||
|  | ||||
|   while (store) { | ||||
|     // Attempt to dynamically cast to PrefixStore | ||||
|     PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get()); | ||||
|     if (asPrefixStore) { | ||||
|       store = asPrefixStore->getUnderlyingStore(); | ||||
|     } else { | ||||
|       break; // We've reached a non-PrefixStore | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       store != nullptr, "Underlying Non-PrefixStore shouldn't be null."); | ||||
|   return store; | ||||
| } | ||||
|  | ||||
| } // namespace c10d | ||||
|  | ||||
| @ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store { | ||||
|  | ||||
|   c10::intrusive_ptr<Store> getUnderlyingStore(); | ||||
|  | ||||
|   // Recursively to fetch the store before layers of wrapping with PrefixStore. | ||||
|   c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore(); | ||||
|  | ||||
|  protected: | ||||
|   std::string prefix_; | ||||
|   c10::intrusive_ptr<Store> store_; | ||||
|  | ||||
| @ -165,6 +165,18 @@ void ProcessGroup::setGroupName(const std::string& name) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| const std::string& ProcessGroup::getGroupDesc() const { | ||||
|   return pg_desc_; | ||||
| } | ||||
|  | ||||
| void ProcessGroup::setGroupDesc(const std::string& name) { | ||||
|   pg_desc_ = name; | ||||
|   // Also set the group desc for all backends | ||||
|   for (auto& kv : deviceTypeToBackend_) { | ||||
|     kv.second->setGroupDesc(name); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void ProcessGroup::enableCollectivesTiming() { | ||||
|   for (auto& kv : deviceTypeToBackend_) { | ||||
|     kv.second->enableCollectivesTiming(); | ||||
|  | ||||
| @ -694,6 +694,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { | ||||
|  | ||||
|   const std::string& getGroupName() const; | ||||
|   void setGroupName(const std::string& name); | ||||
|   const std::string& getGroupDesc() const; | ||||
|   void setGroupDesc(const std::string& name); | ||||
|   void enableCollectivesTiming(); | ||||
|  | ||||
|   void release_resources() override; | ||||
| @ -724,6 +726,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { | ||||
|   const int size_; | ||||
|   const c10::intrusive_ptr<Options> options_; | ||||
|   const BackendType backendType_; | ||||
|   std::string pg_desc_; | ||||
|  | ||||
|   // Debug level setting. It is parsed once when ProcessGroup is constructed and | ||||
|   // remains the same across use of this process group. | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,7 +1,15 @@ | ||||
| #pragma once | ||||
|  | ||||
| #if defined(__linux__) | ||||
| #include <fcntl.h> | ||||
| #include <sys/stat.h> | ||||
| #include <sys/types.h> | ||||
| #include <unistd.h> | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| #include <atomic> | ||||
| #include <chrono> | ||||
| #include <future> | ||||
| #include <iostream> | ||||
| @ -12,6 +20,7 @@ | ||||
|  | ||||
| #include <torch/csrc/distributed/c10d/Backend.hpp> | ||||
| #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> | ||||
| #include <torch/csrc/distributed/c10d/PrefixStore.hpp> | ||||
| #include <torch/csrc/distributed/c10d/Store.hpp> | ||||
|  | ||||
| #include <ATen/DynamicLibrary.h> | ||||
| @ -26,52 +35,74 @@ | ||||
| #include <torch/custom_class.h> | ||||
|  | ||||
| namespace c10d { | ||||
| // Environment variable which controls whether we perform a NCCL healt check | ||||
| // Control whether we perform a NCCL health check or not | ||||
| // which ensures communicators are healthy at the beginning of init. | ||||
| static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = { | ||||
|     "TORCH_ENABLE_NCCL_HEALTH_CHECK", | ||||
|     "ENABLE_NCCL_HEALTH_CHECK"}; | ||||
|  | ||||
| // Environment variable which controls whether or not wait() is blocking or | ||||
| // non-blocking. | ||||
| // Control whether or not wait() is blocking or non-blocking. | ||||
| static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = { | ||||
|     "TORCH_NCCL_BLOCKING_WAIT", | ||||
|     "NCCL_BLOCKING_WAIT"}; | ||||
|  | ||||
| // Environment variable which controls whether or not we perform Async Error | ||||
| // Handling with NCCL. | ||||
| // Control whether or not we perform Async Error Handling with NCCL. | ||||
| static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = { | ||||
|     "TORCH_NCCL_ASYNC_ERROR_HANDLING", | ||||
|     "NCCL_ASYNC_ERROR_HANDLING"}; | ||||
|  | ||||
| // Environment Variable to control whether dumping debug info on watchdog | ||||
| // Control whether dumping debug info on watchdog | ||||
| // timeout is enabled. This variable must be set together with | ||||
| // TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0. | ||||
| static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = { | ||||
|     "TORCH_NCCL_DUMP_ON_TIMEOUT"}; | ||||
|  | ||||
| // Environment Variable to control whether Desync Debug is enabled. | ||||
| // This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING. | ||||
| // Control whether Desync Debug is enabled. This variable must be set | ||||
| // together with TORCH_NCCL_ASYNC_ERROR_HANDLING. | ||||
| static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = { | ||||
|     "TORCH_NCCL_DESYNC_DEBUG", | ||||
|     "NCCL_DESYNC_DEBUG"}; | ||||
|  | ||||
| // Enable recording start-events for all ProcessGroupNCCL collectives, and | ||||
| // compute accurate collective timing per-collective. (Note: end-events are | ||||
| // recorded by default. Turn on this flag can increase chances of a watchdog | ||||
| // hang due to performing a CUDA event query which eventually calls | ||||
| // cudaEventElapsedTime() API. | ||||
| static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = { | ||||
|     "TORCH_NCCL_ENABLE_TIMING", | ||||
|     "NCCL_ENABLE_TIMING"}; | ||||
|  | ||||
| // Enable monitoring thread which aborts the process when the ProcessGroupNCCL | ||||
| // Watchdog thread gets stuck and no heartbeat is detected after | ||||
| // TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL | ||||
| // APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged | ||||
| // time than necessary tying up cluster resources. | ||||
| static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = { | ||||
|     "TORCH_NCCL_ENABLE_MONITORING"}; | ||||
|  | ||||
| // Control the watchdog heartbeat timeout period after which the monitoring | ||||
| // thread will abort the process. | ||||
| static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = { | ||||
|     "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"}; | ||||
|  | ||||
| // The maximum number of events we store in the flight recorder's ring buffer. | ||||
| // (One event could be the start or end of a collective, for example). | ||||
| static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = { | ||||
|     "TORCH_NCCL_TRACE_BUFFER_SIZE"}; | ||||
|  | ||||
| // Control how much extra time we will wait for dumping the debugging info | ||||
| // before we exit and throws timeout exception. | ||||
| static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = { | ||||
|     "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"}; | ||||
|  | ||||
| // Control the interval inside the watchdog thread to check the coordinated | ||||
| // signal from other ranks, e.g. to dump the debugging information. | ||||
| static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = { | ||||
|     "TORCH_NCCL_COORD_CHECK_MILSEC"}; | ||||
|  | ||||
| constexpr const char* NCCL_BACKEND_NAME = "nccl"; | ||||
|  | ||||
| constexpr const char* TIMEOUT_DUMP = "timeout_dump"; | ||||
| constexpr const char* EXCEPTION_DUMP = "exception_dump"; | ||||
|  | ||||
| constexpr auto kProcessGroupNCCLDefaultTimeout = | ||||
|     std::chrono::milliseconds(10 * 60 * 1000); | ||||
| @ -110,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK = | ||||
|     {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK", | ||||
|      "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"}; | ||||
|  | ||||
| #if defined(__linux__) | ||||
| struct DumpPipe { | ||||
|   DumpPipe(int rank) { | ||||
|     std::string fileStem = | ||||
|         getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, ""); | ||||
|     if (fileStem.empty() || | ||||
|         getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { | ||||
|       return; | ||||
|     } | ||||
|     TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty"); | ||||
|     std::string filename = c10::str(fileStem, rank, ".pipe"); | ||||
|     TORCH_CHECK( | ||||
|         unlink(filename.c_str()) != -1 || errno == ENOENT, | ||||
|         "Error removing existing named pipe ", | ||||
|         filename); | ||||
|     TORCH_CHECK( | ||||
|         mkfifo(filename.c_str(), 0666) != -1, | ||||
|         "Error creating named pipe ", | ||||
|         filename); | ||||
|     fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); | ||||
|     LOG(INFO) << "Pipe file " << filename | ||||
|               << " has been opened, write to it to trigger NCCL Debug Dump."; | ||||
|     TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename); | ||||
|   } | ||||
|   bool shouldDump() { | ||||
|     if (fd_ == -1) { | ||||
|       return false; | ||||
|     } | ||||
|     char buf[128]; | ||||
|     // non-blocking from O_NONBLOCK above. | ||||
|     // Ignore EINTR because we already will poll this | ||||
|     // again later. | ||||
|     ssize_t bytesRead = read(fd_, &buf, 128); | ||||
|     return bytesRead > 0; | ||||
|   } | ||||
|   ~DumpPipe() { | ||||
|     if (fd_ != -1) { | ||||
|       close(fd_); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   int fd_ = -1; | ||||
| }; | ||||
| #else | ||||
| struct DumpPipe { | ||||
|   DumpPipe(int rank) {} | ||||
|   bool shouldDump() { | ||||
|     return false; | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| // ProcessGroupNCCL implements NCCL bindings for c10d. | ||||
| // | ||||
| // All functions of the class are expected to be called in the same order | ||||
| @ -205,6 +289,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|     uint64_t getSequencenumber() const override; | ||||
|  | ||||
|     const std::string& logPrefix() const; | ||||
|  | ||||
|     // Helper function that sets an exception_ptr on the WorkNCCL object. | ||||
|     void setException(std::exception_ptr exception_ptr); | ||||
|  | ||||
| @ -358,6 +444,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|     // via `ncclCommSplit` | ||||
|     std::shared_ptr<ProcessGroupNCCL> split_from; | ||||
|     int64_t split_color{0}; | ||||
|     std::string group_name; | ||||
|   }; | ||||
|  | ||||
|   // If you wish to create multiple process groups, each with a potentially | ||||
| @ -540,8 +627,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   void enableCollectivesTiming() override; | ||||
|  | ||||
|   // Provide an API for users to define their own ways to store NCCL debug info. | ||||
|   void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer); | ||||
|   // Helper function for iteratively aborting communicators in the provided map | ||||
|   void abortCommsFromMap( | ||||
|       std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>& | ||||
|           ncclCommsMap, | ||||
|       c10::optional<std::string> abortReason); | ||||
|  | ||||
|   // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) | ||||
|   // instead of relying on ProcessGroupNCCL destructor. | ||||
| @ -694,6 +784,19 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // Desync debug helper | ||||
|   void logWorkEnd(WorkNCCL& work); | ||||
|  | ||||
|   // Generates a prefix that is unique to this process group and rank, for | ||||
|   // disambiguating logs | ||||
|   std::string createLogPrefix() const; | ||||
|  | ||||
|   // Returns the unique prefix created in createLogPrefix | ||||
|   const std::string& logPrefix() const; | ||||
|  | ||||
|   // Returns the global rank of the device. This function assumes that users | ||||
|   // always create a default global process group(PG) which includes all | ||||
|   // devices. It is called in the constructor of ProcessGroupNCCL, so it always | ||||
|   // return the rank_ of the the very first PG created, aka, default global PG. | ||||
|   const int& globalRank() const; | ||||
|  | ||||
|  protected: | ||||
|   // Function that runs as part of a separate thread aside from watchdog | ||||
|   // thread because we need to check the heartbeat from watchdog thread | ||||
| @ -712,6 +815,19 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // for dump completion. | ||||
|   std::future<bool> launchAsyncDebugDump(); | ||||
|  | ||||
|   // Helper to wait up to the specified timeout and then abandon the dump. | ||||
|   // Logs on timeout, and asserts the future's status is as expected. | ||||
|   void waitForDumpOrTimeout( | ||||
|       std::future<bool>& fut, | ||||
|       const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime, | ||||
|       size_t timeout_sec = 30); | ||||
|  | ||||
|   // A helper function to wait for a future to complete or timeout. | ||||
|   void waitForFutureOrTimeout( | ||||
|       std::future<bool>& fut, | ||||
|       const std::chrono::milliseconds& timeOutMilSec, | ||||
|       const std::string& futDescription); | ||||
|  | ||||
|   // When watchdog timeout, this function will be called and return debug info | ||||
|   // for users. For now we only get information from retrieveDesyncReport. | ||||
|   // We are working on enabling more useful debug information for watchdog | ||||
| @ -720,9 +836,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   static const int64_t kWatchdogThreadSleepMillis; | ||||
|  | ||||
|   // The store is used to broadcast the NCCL unique ID of rank 0. | ||||
|   // The store is used to broadcast the NCCL unique ID of rank 0. This store | ||||
|   // comes with prefix and it is different across ProcessGroup NCCL instances | ||||
|   // (aka, different ProcessGroups). | ||||
|   c10::intrusive_ptr<Store> store_; | ||||
|  | ||||
|   // Reference to the store without prefix so that keys are same across all | ||||
|   // ProcessGroup NCCL instances and (key, value) pairs written to the store are | ||||
|   // global. | ||||
|   c10::intrusive_ptr<Store> globalStore_; | ||||
|  | ||||
|   bool storeError_{false}; | ||||
|  | ||||
|   const c10::intrusive_ptr<Options> options_; | ||||
| @ -781,11 +904,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   std::mutex mutex_; | ||||
|  | ||||
|   // Heartbeat of watchdog thread. | ||||
|   uint64_t heartbeat_; | ||||
|   std::atomic_uint64_t heartbeat_; | ||||
|  | ||||
|   // The time interval used for deciding whether there is no watchdog heartbeat. | ||||
|   int heartbeatTimeoutInSec_; | ||||
|  | ||||
|   // Extra time of sleep when waiting for timeout dump to finish. | ||||
|   int waitTimeoutDumpInMilSec_; | ||||
|  | ||||
|   // Interval of check coordinated signals in ProcessGroupNCCL from other ranks | ||||
|   // e.g., trigger the dump of the debugging info for timeout when notified. | ||||
|   int coordCheckIntervalMilSec_; | ||||
|  | ||||
|   // Size of ring buffer where we store NCCL Traces for debugging. | ||||
|   int ncclTraceBufferSize_; | ||||
|  | ||||
| @ -815,6 +945,15 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // Whether there are hooks pending to be fired | ||||
|   std::atomic<bool> hasPendingHooks_; | ||||
|  | ||||
|   // This is the signal from watchdog threads to indicate whether the monitor | ||||
|   // thread should dump. Making it static so that it is accessiable from all the | ||||
|   // PGs. With this flag, monitor thread would dump debug info under any one of | ||||
|   // the 3 conditions: 1: this flag is set to true by the watchdog thread when | ||||
|   // it detects a timeout. 2: timeout signal is received from | ||||
|   // other ranks through tcpstore 3: no heartbeat of watchdog Note that only the | ||||
|   // monitor thread from PG0 should dump the debug info and only once | ||||
|   static std::atomic<bool> shouldDump_; | ||||
|  | ||||
|   // Mutex to Guard workMetaList_ | ||||
|   std::mutex workMetaListMutex_; | ||||
|  | ||||
| @ -823,9 +962,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   bool writeDebugInfo_ = false; | ||||
|  | ||||
|   // Mutex to Guard the check of writeDebugInfo_ | ||||
|   std::mutex writeDebugInfoMutex_; | ||||
|  | ||||
|   // Condition Variable for watchdog thread sleep | ||||
|   std::condition_variable workMetaListCV_; | ||||
|  | ||||
| @ -902,8 +1038,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|   // Whether or not to enable timeout root cause analysis. | ||||
|   bool desyncDebug_; | ||||
|  | ||||
|   // Whether or not to dump debug info on timeout | ||||
|   bool dumpOnTimeout_; | ||||
|   // Whether or not to dump debug info on exception including both watchdog | ||||
|   // timeout and nccl errors. | ||||
|   bool dumpOnException_; | ||||
|  | ||||
|   // Whether or not to create start CUDAEvent and enable timing for start | ||||
|   // and end events. Note that enableTiming_ is always true if desyncDebug_ | ||||
| @ -929,14 +1066,25 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|  | ||||
|   std::exception_ptr watchDogException_ = nullptr; | ||||
|  | ||||
|   // The callback function to store NCCL debug info. | ||||
|   std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr; | ||||
|  | ||||
|   size_t uid_; | ||||
|  | ||||
|   std::string logPrefix_; | ||||
| }; | ||||
|  | ||||
| TORCH_API std::string dump_nccl_trace(); | ||||
|  | ||||
| // Gets a mutable reference to a global optional function.  Heartbeat Monitor | ||||
| // will query this function and if available, call it to dump traces. Inside | ||||
| // fbcode, we store a function here that uses an internal tool for process | ||||
| // tracing | ||||
| TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper(); | ||||
|  | ||||
| // Similar to get_cpp_trace_dumper, this stores a function defined in | ||||
| // torch-python layer that lets us check whether the GIL can be acquired, | ||||
| // helpful for instrumenting in cases where a hang was observed. | ||||
| typedef bool (*gil_checker_t)(); | ||||
|  | ||||
| TORCH_API gil_checker_t& get_gil_checker(); | ||||
| } // namespace c10d | ||||
|  | ||||
| #endif // USE_C10D_NCCL | ||||
|  | ||||
| @ -13,7 +13,6 @@ | ||||
| #include <string> | ||||
| #include <system_error> | ||||
| #include <vector> | ||||
|  | ||||
| namespace c10d { | ||||
|  | ||||
| /* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */ | ||||
| @ -269,10 +268,20 @@ inline std::string retrieveDesyncReport( | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| DebugInfoWriter::DebugInfoWriter(int rank) { | ||||
|   std::string fileName = getCvarString( | ||||
|       {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|   filename_ = c10::str(fileName, rank); | ||||
| /* Helper used by work::getDuration() and nccl flight recorder */ | ||||
| float getDurationFromFirstEvent( | ||||
|     const std::vector<at::cuda::CUDAEvent>& ncclStartEvents, | ||||
|     const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) { | ||||
|   TORCH_CHECK( | ||||
|       ncclStartEvents.size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup, but found multiple start events."); | ||||
|   TORCH_CHECK( | ||||
|       ncclEndEvents.size() == 1, | ||||
|       "getDuration only works for single device per ProcessGroup, but found multiple end events."); | ||||
|   TORCH_CHECK( | ||||
|       ncclEndEvents[0].query(), | ||||
|       "getDuration can only be called after work is succeeded.") | ||||
|   return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]); | ||||
| } | ||||
|  | ||||
| DebugInfoWriter::~DebugInfoWriter() = default; | ||||
| @ -293,6 +302,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) { | ||||
|   LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; | ||||
| } | ||||
|  | ||||
| DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { | ||||
|   if (writer_ == nullptr) { | ||||
|     std::string fileNamePrefix = getCvarString( | ||||
|         {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); | ||||
|     // Using std::unique_ptr here to auto-delete the writer object | ||||
|     // when the pointer itself is destroyed. | ||||
|     std::unique_ptr<DebugInfoWriter> writerPtr( | ||||
|         new DebugInfoWriter(fileNamePrefix, rank)); | ||||
|     DebugInfoWriter::registerWriter(std::move(writerPtr)); | ||||
|   } | ||||
|   return *writer_; | ||||
| } | ||||
|  | ||||
| void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) { | ||||
|   TORCH_CHECK_WITH( | ||||
|       DistBackendError, | ||||
|       hasWriterRegistered_.load() == false, | ||||
|       "debugInfoWriter already registered"); | ||||
|   hasWriterRegistered_.store(true); | ||||
|   writer_ = std::move(writer); | ||||
| } | ||||
|  | ||||
| std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr; | ||||
| std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false); | ||||
|  | ||||
| inline std::string pickle_str(const c10::IValue& v) { | ||||
|   std::vector<char> result; | ||||
|   { | ||||
| @ -317,6 +351,18 @@ inline c10::List<c10::IValue> new_list() { | ||||
|   return c10::List<c10::IValue>(c10::AnyType::get()); | ||||
| } | ||||
|  | ||||
| inline std::string ranks_str(const std::vector<uint64_t>& ranks) { | ||||
|   std::string str; | ||||
|   for (const auto& rank : ranks) { | ||||
|     if (str.empty()) { | ||||
|       str = std::to_string(rank); | ||||
|     } else { | ||||
|       str += ", " + std::to_string(rank); | ||||
|     } | ||||
|   } | ||||
|   return c10::str("[", str, "]"); | ||||
| } | ||||
|  | ||||
| struct NCCLTraceBuffer { | ||||
|   static NCCLTraceBuffer* get() { | ||||
|     // intentionally leak on exit | ||||
| @ -336,11 +382,12 @@ struct NCCLTraceBuffer { | ||||
|                 // buffer this entry will be located to | ||||
|                 // update state information | ||||
|     size_t pg_id_; | ||||
|     std::string pg_name_; | ||||
|     size_t seq_id_; // as tracked by the process group | ||||
|     const char* profiling_name_; | ||||
|     std::string profiling_name_; | ||||
|  | ||||
|     std::shared_ptr<torch::CapturedTraceback> traceback_; | ||||
|     // we borrow pointser to start_ and end_ so we can query the state | ||||
|     // we borrow pointers to start_ and end_ so we can query the state | ||||
|     // on reporting. However, once the event is completed, the call | ||||
|     // to `complete` will clear these. | ||||
|     EventList *start_, *end_; | ||||
| @ -348,8 +395,18 @@ struct NCCLTraceBuffer { | ||||
|     // timestamp when the entry was created, likely close to the time the work | ||||
|     // was 'enqueued'- not necessarily started | ||||
|     c10::time_t time_created_; | ||||
|     c10::optional<float> duration_; | ||||
|  | ||||
|     const char* state_ = "scheduled"; | ||||
|     // timestamp when our CPU threads discovered that the kernel started. | ||||
|     // will always be _after_ it actually started, and can be very late | ||||
|     // if the watchdog thread got stuck on CUDA APIs. | ||||
|     c10::optional<c10::time_t> time_discovered_started_; | ||||
|  | ||||
|     // timestamp when our CPU threads discovered that the kernel completed. | ||||
|     // will always be _after_ it actually complated, and can be the same time | ||||
|     // as the discovery of the start if the watchdog thread is stuck on CUDA | ||||
|     // APIs | ||||
|     c10::optional<c10::time_t> time_discovered_completed_; | ||||
|  | ||||
|     // size information for input/output tensors | ||||
|     c10::SmallVector<int, 4> input_dims_; | ||||
| @ -369,8 +426,9 @@ struct NCCLTraceBuffer { | ||||
|  | ||||
|   c10::optional<size_t> record( | ||||
|       size_t pg_id, | ||||
|       const std::string& pg_name, | ||||
|       size_t seq_id, | ||||
|       const char* profiling_name, | ||||
|       std::string profiling_name, | ||||
|       const std::vector<at::Tensor>& inputs, | ||||
|       const std::vector<at::Tensor>& outputs, | ||||
|       EventList* start, | ||||
| @ -385,8 +443,9 @@ struct NCCLTraceBuffer { | ||||
|     auto te = Entry{ | ||||
|         id_, | ||||
|         pg_id, | ||||
|         pg_name, | ||||
|         seq_id, | ||||
|         profiling_name, | ||||
|         std::move(profiling_name), | ||||
|         std::move(traceback), | ||||
|         std::move(start), | ||||
|         std::move(end), | ||||
| @ -424,8 +483,8 @@ struct NCCLTraceBuffer { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       if (started) { | ||||
|         r.state_ = "started"; | ||||
|       if (started && !r.time_discovered_started_) { | ||||
|         r.time_discovered_started_ = c10::getTime(); | ||||
|       } | ||||
|     } | ||||
|     if (r.end_ != nullptr) { | ||||
| @ -436,8 +495,8 @@ struct NCCLTraceBuffer { | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       if (completed) { | ||||
|         r.state_ = "completed"; | ||||
|       if (completed && !r.time_discovered_completed_) { | ||||
|         r.time_discovered_completed_ = c10::getTime(); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @ -456,35 +515,97 @@ struct NCCLTraceBuffer { | ||||
|     return result; | ||||
|   } | ||||
|  | ||||
|   void retire_id(c10::optional<size_t> id) { | ||||
|   /* | ||||
|   Mark an Event as completed and free its events. | ||||
|  | ||||
|   This is called by the watchdog thread, and is asynchronous from the | ||||
|   perspective of the main thread. | ||||
|  | ||||
|   compute_duration defaults to true since retire_id is only called in the | ||||
|   watchdog thread, which is currently a place we call cuda APIs which may hang, | ||||
|   but care should be taken to avoid computing duration in any function that must | ||||
|   never hang. (timing must also be enabled for compute_duration - see | ||||
|   TORCH_NCCL_ENABLE_TIMING). | ||||
|   */ | ||||
|   void retire_id(c10::optional<size_t> id, bool compute_duration = true) { | ||||
|     if (!enabled_ || !id) { | ||||
|       return; | ||||
|     } | ||||
|     std::lock_guard<std::mutex> guard(mutex_); | ||||
|     auto& entry = entries_.at(*id % max_entries_); | ||||
|     if (entry.id_ == *id) { | ||||
|       update_state(entry); | ||||
|       entry.retired_ = true; | ||||
|       entry.start_ = entry.end_ = nullptr; | ||||
|  | ||||
|     bool can_compute_duration = false; | ||||
|     EventList* startEvents = nullptr; | ||||
|     EventList* endEvents = nullptr; | ||||
|     c10::optional<float> duration = c10::nullopt; | ||||
|  | ||||
|     std::unique_lock<std::mutex> guard(mutex_); | ||||
|  | ||||
|     Entry* entry = &entries_.at(*id % max_entries_); | ||||
|     if (entry->id_ == *id) { | ||||
|       update_state(*entry); | ||||
|  | ||||
|       if (compute_duration) { | ||||
|         can_compute_duration = entry->time_discovered_completed_.has_value() && | ||||
|             entry->start_ && entry->end_; | ||||
|         startEvents = entry->start_; | ||||
|         endEvents = entry->end_; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (can_compute_duration) { | ||||
|       // Compute duration without without holding the lock, because | ||||
|       // cudaEventDuration() can hang, and we need to acquire the lock before we | ||||
|       // can dump(), which we never want to block. | ||||
|       guard.unlock(); | ||||
|       duration = getDurationFromFirstEvent(*startEvents, *endEvents); | ||||
|       guard.lock(); | ||||
|  | ||||
|       // Refresh the entry pointer, see if the entry has been overwritten | ||||
|       entry = &entries_.at(*id % max_entries_); | ||||
|       if (entry->id_ != *id) { | ||||
|         LOG(INFO) | ||||
|             << "retire_id abandoned for id " << *id | ||||
|             << ", event was overwritten while waiting to compute duration."; | ||||
|         return; | ||||
|       } | ||||
|       if (duration.has_value()) { | ||||
|         entry->duration_ = duration.value(); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     entry->retired_ = true; | ||||
|     entry->start_ = entry->end_ = nullptr; | ||||
|   } | ||||
|  | ||||
|   std::string dump() { | ||||
|   std::string dump( | ||||
|       const c10::optional<std::unordered_map< | ||||
|           std::string, | ||||
|           std::unordered_map<std::string, std::string>>>& ncclDumpMap) { | ||||
|     auto result = dump_entries(); | ||||
|     auto entries = new_list(); | ||||
|     c10::IValue pg_id_s = "pg_id"; | ||||
|     c10::IValue seq_id_s = "seq_id"; | ||||
|     c10::IValue profiling_name_s = "profiling_name"; | ||||
|     c10::IValue input_sizes_s = "input_sizes"; | ||||
|     c10::IValue output_sizes_s = "output_sizes"; | ||||
|     c10::IValue time_created_s = "time_created_us"; | ||||
|     c10::IValue entries_key = "entries"; | ||||
|     c10::IValue nccl_comm_key = "nccl_comm_state"; | ||||
|     c10::IValue version_key = "version"; | ||||
|     // Update whenever changing contents or formatting of the dump | ||||
|     // (minor when adding fields, major when changing existing fields) | ||||
|     c10::IValue version_val = "1.1"; | ||||
|  | ||||
|     c10::IValue frames_s = "frames"; | ||||
|     c10::IValue state_s = "state"; | ||||
|     c10::IValue line_s = "line"; | ||||
|     c10::IValue name_s = "name"; | ||||
|     c10::IValue filename_s = "filename"; | ||||
|     c10::IValue retired_s = "retired"; | ||||
|     c10::IValue pg_id_key = "pg_id"; | ||||
|     c10::IValue pg_name_key = "process_group"; | ||||
|     c10::IValue seq_id_key = "seq_id"; | ||||
|     c10::IValue profiling_name_key = "profiling_name"; | ||||
|     c10::IValue input_sizes_key = "input_sizes"; | ||||
|     c10::IValue output_sizes_key = "output_sizes"; | ||||
|     c10::IValue time_created_key = "time_created_ns"; | ||||
|     c10::IValue duration_key = "duration_ms"; | ||||
|  | ||||
|     c10::IValue frames_key = "frames"; | ||||
|     c10::IValue state_key = "state"; | ||||
|     c10::IValue line_key = "line"; | ||||
|     c10::IValue name_key = "name"; | ||||
|     c10::IValue filename_key = "filename"; | ||||
|     c10::IValue retired_key = "retired"; | ||||
|     c10::IValue time_discovered_started_key = "time_discovered_started_ns"; | ||||
|     c10::IValue time_discovered_completed_key = "time_discovered_completed_ns"; | ||||
|  | ||||
|     std::vector<torch::CapturedTraceback*> tracebacks; | ||||
|     for (auto& e : result) { | ||||
| @ -494,9 +615,9 @@ struct NCCLTraceBuffer { | ||||
|     std::vector<c10::IValue> all_frames; | ||||
|     for (const auto& f : stracebacks.all_frames) { | ||||
|       auto d = new_dict(); | ||||
|       d.insert(name_s, f.funcname); | ||||
|       d.insert(filename_s, f.filename); | ||||
|       d.insert(line_s, int64_t(f.lineno)); | ||||
|       d.insert(name_key, f.funcname); | ||||
|       d.insert(filename_key, f.filename); | ||||
|       d.insert(line_key, int64_t(f.lineno)); | ||||
|       all_frames.emplace_back(std::move(d)); | ||||
|     } | ||||
|  | ||||
| @ -504,10 +625,14 @@ struct NCCLTraceBuffer { | ||||
|       auto& e = result.at(i); | ||||
|       auto& tb = stracebacks.tracebacks.at(i); | ||||
|       auto dict = new_dict(); | ||||
|       dict.insert(pg_id_s, int64_t(e.pg_id_)); | ||||
|       dict.insert(seq_id_s, int64_t(e.seq_id_)); | ||||
|       dict.insert(profiling_name_s, e.profiling_name_); | ||||
|       dict.insert(time_created_s, int64_t(e.time_created_ / 1000)); | ||||
|       dict.insert(pg_id_key, int64_t(e.pg_id_)); | ||||
|       dict.insert(pg_name_key, e.pg_name_); | ||||
|       dict.insert(seq_id_key, int64_t(e.seq_id_)); | ||||
|       dict.insert(profiling_name_key, e.profiling_name_); | ||||
|       dict.insert(time_created_key, int64_t(e.time_created_)); | ||||
|       if (e.duration_) { | ||||
|         dict.insert(duration_key, *e.duration_); | ||||
|       } | ||||
|  | ||||
|       auto it = e.sizes_.begin(); | ||||
|       auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) { | ||||
| @ -523,19 +648,55 @@ struct NCCLTraceBuffer { | ||||
|         return sizes; | ||||
|       }; | ||||
|  | ||||
|       dict.insert(input_sizes_s, read_sizes(e.input_dims_)); | ||||
|       dict.insert(output_sizes_s, read_sizes(e.output_dims_)); | ||||
|       dict.insert(state_s, e.state_); | ||||
|       dict.insert(retired_s, e.retired_); | ||||
|       dict.insert(input_sizes_key, read_sizes(e.input_dims_)); | ||||
|       dict.insert(output_sizes_key, read_sizes(e.output_dims_)); | ||||
|       if (e.time_discovered_completed_.has_value()) { | ||||
|         dict.insert(state_key, "completed"); | ||||
|       } else if (e.time_discovered_started_.has_value()) { | ||||
|         dict.insert(state_key, "started"); | ||||
|       } else { | ||||
|         dict.insert(state_key, "scheduled"); | ||||
|       } | ||||
|  | ||||
|       dict.insert( | ||||
|           time_discovered_started_key, | ||||
|           e.time_discovered_started_.has_value() | ||||
|               ? int64_t(*e.time_discovered_started_) | ||||
|               : c10::IValue()); | ||||
|       dict.insert( | ||||
|           time_discovered_completed_key, | ||||
|           e.time_discovered_completed_.has_value() | ||||
|               ? int64_t(*e.time_discovered_completed_) | ||||
|               : c10::IValue()); | ||||
|       dict.insert(retired_key, e.retired_); | ||||
|  | ||||
|       auto frames = new_list(); | ||||
|       for (int64_t frame : tb) { | ||||
|         frames.push_back(all_frames.at(frame)); | ||||
|       } | ||||
|       dict.insert(frames_s, frames); | ||||
|       dict.insert(frames_key, frames); | ||||
|       entries.push_back(dict); | ||||
|     } | ||||
|     return pickle_str(entries); | ||||
|     // convert ncclDumpMap into a dictionary | ||||
|     auto per_comm_dict = new_dict(); | ||||
|     if (ncclDumpMap.has_value()) { | ||||
|       for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) { | ||||
|         auto inner_dict = new_dict(); | ||||
|         for (const auto& [key, value] : ncclDump) { | ||||
|           inner_dict.insert(key, value); | ||||
|         } | ||||
|         per_comm_dict.insert(ncclId, inner_dict); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     auto dict = new_dict(); | ||||
|     dict.insert(entries_key, entries); | ||||
|     dict.insert(version_key, version_val); | ||||
|     if (per_comm_dict.size() > 0) { | ||||
|       dict.insert(nccl_comm_key, per_comm_dict); | ||||
|     } | ||||
|  | ||||
|     return pickle_str(dict); | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -50,6 +50,35 @@ | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| #ifdef USE_C10D_NCCL | ||||
|  | ||||
| bool acquire_gil() { | ||||
|   // basically if this function can acquire the gil, it will return quickly. | ||||
|   // if not, it will hang forever.  The idea is to call this from a thread | ||||
|   // wrapped in a future, and then check the future after a timeout, to | ||||
|   // determine whether we're facing gil contention. | ||||
|   if (Py_IsInitialized()) { | ||||
|     pybind11::gil_scoped_acquire gil; | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   // If we end up here, its probably still a "pass" from the perspective of | ||||
|   // checking whether python is stuck. but currently we don't check the return | ||||
|   // value of this function anyway, just check whether it returned quickly vs | ||||
|   // timing out.  Taking a long time is the main sign of trouble.  Fast return | ||||
|   // with true or with false is both OK from the perspective of debugging python | ||||
|   // hangs. | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| bool registerGilChecker() { | ||||
|   c10d::get_gil_checker() = &acquire_gil; | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| static bool registered = registerGilChecker(); | ||||
| #endif // USE_C10D_NCCL | ||||
|  | ||||
| // Wrapper to ensure GIL is released before destructing ProcessGroupGloo | ||||
| // TODO: move this somewhere more generally useful | ||||
| template <typename T> | ||||
| @ -1033,6 +1062,29 @@ Example:: | ||||
|     >>> store.add("first_key", 6) | ||||
|     >>> # Should return 7 | ||||
|     >>> store.get("first_key") | ||||
| )") | ||||
|           .def( | ||||
|               "check", | ||||
|               &::c10d::Store::check, | ||||
|               py::call_guard<py::gil_scoped_release>(), | ||||
|               R"( | ||||
| The call to check whether a given list of ``keys`` have value stored in | ||||
| the store. This call immediately returns in normal cases but still suffers | ||||
| from some edge deadlock cases, e.g, calling check after TCPStore has been destroyed. | ||||
| Calling :meth:`~torch.distributed.store.check` with a list of keys that | ||||
| one wants to check whether stored in the store or not. | ||||
|  | ||||
| Arguments: | ||||
|     keys (lisr[str]): The keys to query whether stored in the store. | ||||
|  | ||||
| Example:: | ||||
|     >>> import torch.distributed as dist | ||||
|     >>> from datetime import timedelta | ||||
|     >>> # Using TCPStore as an example, other store types can also be used | ||||
|     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) | ||||
|     >>> store.add("first_key", 1) | ||||
|     >>> # Should return 7 | ||||
|     >>> store.check(["first_key"]) | ||||
| )") | ||||
|           .def( | ||||
|               "delete_key", | ||||
| @ -1404,7 +1456,11 @@ Arguments: | ||||
|       .def_property_readonly( | ||||
|           "underlying_store", | ||||
|           &::c10d::PrefixStore::getUnderlyingStore, | ||||
|           R"(Gets the underlying store object that PrefixStore wraps around.)"); | ||||
|           R"(Gets the underlying store object that PrefixStore wraps around.)") | ||||
|       .def_property_readonly( | ||||
|           "_underlying_non_prefix_store", | ||||
|           &::c10d::PrefixStore::getUnderlyingNonPrefixStore, | ||||
|           R"(Recursively to get the store before layers of wrapping with PrefixStore.)"); | ||||
|  | ||||
|   auto processGroup = | ||||
|       py::class_< | ||||
| @ -1807,6 +1863,15 @@ Arguments: | ||||
|               "group_name", | ||||
|               &::c10d::ProcessGroup::getGroupName, | ||||
|               "(Gets this process group name. It's cluster unique)") | ||||
|           .def( | ||||
|               "_set_group_desc", | ||||
|               &::c10d::ProcessGroup::setGroupDesc, | ||||
|               py::call_guard<py::gil_scoped_acquire>(), | ||||
|               "Sets the process group description. This is an internal C10D method, do not use.") | ||||
|           .def_property_readonly( | ||||
|               "group_desc", | ||||
|               &::c10d::ProcessGroup::getGroupDesc, | ||||
|               "Gets this process group description") | ||||
|           .def_property( | ||||
|               "bound_device_id", | ||||
|               &::c10d::ProcessGroup::getBoundDeviceId, | ||||
| @ -2387,7 +2452,9 @@ Example:: | ||||
|       .def_readwrite( | ||||
|           "split_from", &::c10d::ProcessGroupNCCL::Options::split_from) | ||||
|       .def_readwrite( | ||||
|           "split_color", &::c10d::ProcessGroupNCCL::Options::split_color); | ||||
|           "split_color", &::c10d::ProcessGroupNCCL::Options::split_color) | ||||
|       .def_readwrite( | ||||
|           "group_name", &::c10d::ProcessGroupNCCL::Options::group_name); | ||||
|  | ||||
| #endif | ||||
|  | ||||
|  | ||||
							
								
								
									
										448
									
								
								torch/distributed/_state_dict_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										448
									
								
								torch/distributed/_state_dict_utils.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,448 @@ | ||||
| import io | ||||
| import math | ||||
| from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.nn.functional as F | ||||
| from torch.distributed._functional_collectives import AsyncCollectiveTensor | ||||
|  | ||||
| if dist.is_available() or TYPE_CHECKING: | ||||
|     from torch.distributed import distributed_c10d | ||||
|     from torch.distributed._shard.sharded_tensor import ShardedTensor | ||||
|     from torch.distributed._tensor import DTensor, Replicate | ||||
|  | ||||
|  | ||||
| def _identity_func( | ||||
|     obj: torch.Tensor, | ||||
|     pg: Optional[dist.ProcessGroup], | ||||
|     device: Optional[torch.device], | ||||
|     companion_obj: Any, | ||||
| ) -> torch.Tensor: | ||||
|     return obj | ||||
|  | ||||
|  | ||||
| def _all_gather_sharded_tensor( | ||||
|     sharded_tensor: "ShardedTensor", | ||||
|     pg: Optional[dist.ProcessGroup] = None, | ||||
|     device: Optional[torch.device] = None, | ||||
| ) -> torch.Tensor: | ||||
|     if pg is None: | ||||
|         pg = distributed_c10d._get_default_group() | ||||
|     world_size = dist.get_world_size(pg) | ||||
|     shards = sharded_tensor.local_shards() | ||||
|     dim_0_size = sharded_tensor.size()[0]  # type: ignore[index] | ||||
|     tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr] | ||||
|     chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size | ||||
|     pg_device = ( | ||||
|         distributed_c10d._get_pg_default_device(pg) if device is None else device | ||||
|     ) | ||||
|     if shards: | ||||
|         local_tensor = shards[0].tensor.flatten() | ||||
|         if local_tensor.device.type != pg_device.type: | ||||
|             local_tensor = local_tensor.to(pg_device) | ||||
|         num_padding = chunk_size - local_tensor.numel() | ||||
|         if num_padding > 0: | ||||
|             local_tensor = F.pad(local_tensor, [0, num_padding]) | ||||
|     else: | ||||
|         local_tensor = torch.zeros( | ||||
|             chunk_size, dtype=sharded_tensor.dtype, device=pg_device | ||||
|         ) | ||||
|  | ||||
|     tensor = torch.empty( | ||||
|         chunk_size * world_size, | ||||
|         dtype=local_tensor.dtype, | ||||
|         device=pg_device, | ||||
|     ) | ||||
|     dist.all_gather_into_tensor(tensor, local_tensor, group=pg) | ||||
|  | ||||
|     tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) | ||||
|     return tensor | ||||
|  | ||||
|  | ||||
| class CompanionMismatch(Exception): | ||||
|     ... | ||||
|  | ||||
|  | ||||
| def _iterate_state_dict( | ||||
|     iter_object: Any, | ||||
|     sharded_tensor_func: Callable, | ||||
|     dtensor_func: Callable, | ||||
|     tensor_func: Callable, | ||||
|     *, | ||||
|     pg: Optional[dist.ProcessGroup] = None, | ||||
|     device: Optional[torch.device] = None, | ||||
|     cpu_offload: bool = False, | ||||
|     companion_obj: Any = None, | ||||
|     ranks_only: Tuple[int, ...] = tuple(), | ||||
|     type_check: bool = True, | ||||
|     non_blocking: bool = True, | ||||
| ) -> Dict[str, Any]: | ||||
|     """Iterate through the state dict, applying the given functions to each tensor type. | ||||
|  | ||||
|     Args: | ||||
|         iter_object (Any): the target state_dict. | ||||
|         sharded_tensor_func (Callable): the function to apply to ShardedTensor | ||||
|         dtensor_func (Callable): the function to apply to DTensor | ||||
|         tensor_func (Callable): the function to apply to Tensor | ||||
|         pg (Optional[dist.ProcessGroup]): process group passed to tensor functions | ||||
|         device (Optional[torch.device]): device passed to tensor functions | ||||
|         cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored | ||||
|             if a companion_obj is supplied. | ||||
|         companion_obj (Any): A companion object to the state dict. If this object | ||||
|             is supplied, we attempt to copy the tensor to the companion object. | ||||
|         ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will | ||||
|             have the same state_dicts. Otherwise only ranks that in ``ranks_only`` | ||||
|             have the same state_dicts. Other ranks will get empty state_dicts. | ||||
|         type_check (bool): check if the instance data type is a supported type | ||||
|             that can be saved by DCP.  The current supported data types are | ||||
|             torch.Tensor, DTensor, int, float, str, list, dict, None. | ||||
|         non_blocking (bool): whether to use non-blocking copy when copying to the companion object. | ||||
|     """ | ||||
|     # TODO: should we use pytree? | ||||
|     cpu_device = torch.device("cpu") | ||||
|     if isinstance(iter_object, ShardedTensor): | ||||
|         ret = sharded_tensor_func(iter_object, pg, device, companion_obj) | ||||
|     elif isinstance(iter_object, DTensor): | ||||
|         ret = dtensor_func(iter_object, pg, device, companion_obj) | ||||
|     elif isinstance(iter_object, torch.Tensor): | ||||
|         ret = tensor_func(iter_object, pg, device, companion_obj) | ||||
|     elif ( | ||||
|         isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) | ||||
|         or iter_object is None | ||||
|     ): | ||||
|         ret = iter_object | ||||
|     elif isinstance(iter_object, dict): | ||||
|         if companion_obj is not None and ( | ||||
|             not isinstance(companion_obj, dict) | ||||
|             or set(companion_obj.keys()) != set(iter_object.keys()) | ||||
|         ): | ||||
|             raise CompanionMismatch() | ||||
|  | ||||
|         ret = { | ||||
|             key: _iterate_state_dict( | ||||
|                 value, | ||||
|                 sharded_tensor_func, | ||||
|                 dtensor_func, | ||||
|                 tensor_func, | ||||
|                 pg=pg, | ||||
|                 device=device, | ||||
|                 cpu_offload=cpu_offload, | ||||
|                 companion_obj=companion_obj[key] if companion_obj is not None else None, | ||||
|                 ranks_only=ranks_only, | ||||
|                 type_check=type_check, | ||||
|                 non_blocking=non_blocking, | ||||
|             ) | ||||
|             for key, value in iter_object.items() | ||||
|         } | ||||
|     elif isinstance(iter_object, (list, tuple)): | ||||
|         if companion_obj is not None and ( | ||||
|             not isinstance(companion_obj, (list, tuple)) | ||||
|             or len(companion_obj) != len(iter_object) | ||||
|         ): | ||||
|             raise CompanionMismatch() | ||||
|  | ||||
|         ret = [ | ||||
|             _iterate_state_dict( | ||||
|                 v, | ||||
|                 sharded_tensor_func, | ||||
|                 dtensor_func, | ||||
|                 tensor_func, | ||||
|                 pg=pg, | ||||
|                 device=device, | ||||
|                 cpu_offload=cpu_offload, | ||||
|                 companion_obj=companion_obj[idx] if companion_obj is not None else None, | ||||
|                 ranks_only=ranks_only, | ||||
|                 type_check=type_check, | ||||
|                 non_blocking=non_blocking, | ||||
|             ) | ||||
|             for idx, v in enumerate(iter_object) | ||||
|         ] | ||||
|         if isinstance(iter_object, tuple): | ||||
|             ret = tuple(ret) | ||||
|     elif not type_check: | ||||
|         ret = iter_object | ||||
|     else: | ||||
|         raise ValueError(f"Unexpected value type {type(iter_object)}") | ||||
|  | ||||
|     if not ranks_only or dist.get_rank(pg) in ranks_only: | ||||
|         if isinstance(ret, torch.Tensor): | ||||
|             if cpu_offload and companion_obj is None: | ||||
|                 ret = ret.to(cpu_device) | ||||
|  | ||||
|             if companion_obj is not None: | ||||
|                 # TODO: support DTensor | ||||
|                 companion_obj.copy_(ret, non_blocking=non_blocking) | ||||
|                 ret = companion_obj | ||||
|     else: | ||||
|         ret = {} if isinstance(ret, dict) else None | ||||
|  | ||||
|     return ret | ||||
|  | ||||
|  | ||||
| def _gather_state_dict( | ||||
|     state_dict: Dict[str, Any], | ||||
|     *, | ||||
|     pg: Optional[dist.ProcessGroup] = None, | ||||
|     device: Optional[torch.device] = None, | ||||
|     cpu_offload: bool = False, | ||||
|     ranks_only: Tuple[int, ...] = tuple(), | ||||
|     type_check: bool = True, | ||||
| ) -> Dict[str, Any]: | ||||
|     """ | ||||
|     Given a state_dict, this API gathers all the ShardedTensors or DTensors in | ||||
|     the state_dict. | ||||
|  | ||||
|  | ||||
|     Args: | ||||
|         state_dict (Dict[str, Any]): the target sharded state_dict. | ||||
|         pg (Optional[dist.ProcessGroup]): the process group that is used to | ||||
|             gather ShardedTensor. Note that gathering a DTensor will use | ||||
|             the DeviceMesh. So this argument will be ignored when gathering a | ||||
|             DTensor. | ||||
|         device: (Optional[torch.device]): the device that is used to | ||||
|             perform allgather for ShardedTensor. Note that gathering a DTensor | ||||
|             will use the DeviceMesh. So this argument will be ignored when | ||||
|             gathering a DTensor. | ||||
|         cpu_offload (bool): whether to offload the tensors to CPU memory. The | ||||
|             default value is False. | ||||
|         ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will | ||||
|             have the same state_dicts. Otherwise only ranks that in ``ranks_only`` | ||||
|             have the same state_dicts. Other ranks will get empty state_dicts. | ||||
|         type_check: (bool): check if the instance data type is a supported type | ||||
|             that can be saved by DCP.  The current supported data types are | ||||
|             torch.Tensor, DTensor, int, float, str, list, dict, None. | ||||
|  | ||||
|     Returns: | ||||
|         The gathered state dictionary. | ||||
|     """ | ||||
|  | ||||
|     def sharded_tensor_func(value, pg, device, companion_obj): | ||||
|         # ShardedTensor does not seem to record the original device type. | ||||
|         # So if the tensor is moved to CPU, we won't know the original type. | ||||
|         # As a result, we have to rely on the user to tell us the correct one. | ||||
|         cpu_device = torch.device("cpu") | ||||
|         output_tensor = _all_gather_sharded_tensor(value, pg, device) | ||||
|         local_shard_device = ( | ||||
|             value.local_shards()[0].tensor.device | ||||
|             if value.local_shards() | ||||
|             else cpu_device | ||||
|         ) | ||||
|         if output_tensor.device != local_shard_device: | ||||
|             value = output_tensor.to(local_shard_device) | ||||
|         else: | ||||
|             value = output_tensor | ||||
|         return value | ||||
|  | ||||
|     def dtensor_func(value, pg, device, companion_obj): | ||||
|         if value.device != value.device_mesh.device_type: | ||||
|             value = value.to(value.device_mesh.device_type) | ||||
|         # FSDP all_gather: [Shard(0)] -> [Replicate()] | ||||
|         # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] | ||||
|         # 2D FSDP + TP all_gather: | ||||
|         # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] | ||||
|         # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] | ||||
|         placements = [Replicate() for _ in value.placements] | ||||
|         value = value.redistribute( | ||||
|             device_mesh=value.device_mesh, | ||||
|             placements=placements, | ||||
|         ) | ||||
|         # Call `wait()` to force the tensor to be synchronous with respect | ||||
|         # to the main stream. | ||||
|         # See the discussion in https://github.com/pytorch/pytorch/pull/117799. | ||||
|         value = value.to_local() | ||||
|         if isinstance(value, AsyncCollectiveTensor): | ||||
|             value = value.wait() | ||||
|         return value | ||||
|  | ||||
|     return _iterate_state_dict( | ||||
|         state_dict, | ||||
|         sharded_tensor_func, | ||||
|         dtensor_func, | ||||
|         _identity_func, | ||||
|         pg=pg, | ||||
|         device=device, | ||||
|         cpu_offload=cpu_offload, | ||||
|         ranks_only=ranks_only, | ||||
|         type_check=type_check, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _offload_state_dict_to_cpu( | ||||
|     state_dict: Dict[str, Any], | ||||
|     *, | ||||
|     ranks_only: Tuple[int, ...] = tuple(), | ||||
|     type_check: bool = True, | ||||
| ) -> Dict[str, Any]: | ||||
|     """ | ||||
|     Given a state_dict, this API offload all the tensors to CPU memory. | ||||
|  | ||||
|     Args: | ||||
|         state_dict (Dict[str, Any]): the target state_dict. | ||||
|         pg (Optional[dist.ProcessGroup]): the process group that is used to | ||||
|             gather ShardedTensor. Note that gathering a DTensor will use | ||||
|             the DeviceMesh. So this argument will be ignored when gathering a | ||||
|             DTensor. | ||||
|         ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will | ||||
|             have the same state_dicts. Otherwise only ranks that in ``ranks_only`` | ||||
|             have the same state_dicts. Other ranks will get empty state_dicts. | ||||
|         type_check: (bool): check if the instance data type is a supported type | ||||
|             that can be saved by DCP.  The current supported data types are | ||||
|             torch.Tensor, DTensor, int, float, str, list, dict, None. | ||||
|  | ||||
|     Returns: | ||||
|         The gathered state dictionary. | ||||
|     """ | ||||
|  | ||||
|     ret = _iterate_state_dict( | ||||
|         state_dict, | ||||
|         _identity_func, | ||||
|         _identity_func, | ||||
|         _identity_func, | ||||
|         pg=None, | ||||
|         device=None, | ||||
|         cpu_offload=True, | ||||
|         ranks_only=ranks_only, | ||||
|         type_check=type_check, | ||||
|     ) | ||||
|     return ret | ||||
|  | ||||
|  | ||||
| def _copy_state_dict( | ||||
|     state_dict: Dict[str, Any], | ||||
|     copy_state_dict: Dict[str, Any], | ||||
|     non_blocking: bool = False, | ||||
| ): | ||||
|     """ | ||||
|     Copies all tensors in a given state dict into a different state_dict with the | ||||
|     same structure. | ||||
|  | ||||
|     .. warning:: | ||||
|         It is expected by this function that state_dict and copy_state_dict share | ||||
|         the same structure and data types. | ||||
|  | ||||
|     .. warning:: | ||||
|         The current supported data types are | ||||
|             torch.Tensor, DTensor, int, float, str, list, dict, None. | ||||
|  | ||||
|     Args: | ||||
|         state_dict (Dict[str, Any]): the target state_dict. | ||||
|         copy_state_dict (Dict[str, Any]): | ||||
|             The state dict we are copying into. This state_dict must have exactly | ||||
|              the same structure as the source `state_dict`. | ||||
|         non_blocking: (bool): Whether copy ops should be performed asynchronously | ||||
|     """ | ||||
|  | ||||
|     _iterate_state_dict( | ||||
|         state_dict, | ||||
|         _identity_func, | ||||
|         _identity_func, | ||||
|         _identity_func, | ||||
|         pg=None, | ||||
|         device=None, | ||||
|         cpu_offload=False, | ||||
|         ranks_only=tuple(), | ||||
|         companion_obj=copy_state_dict, | ||||
|         type_check=True, | ||||
|         non_blocking=non_blocking, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _create_cpu_state_dict( | ||||
|     state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False | ||||
| ) -> Dict[str, Any]: | ||||
|     """ | ||||
|     Given a state_dict, create another state_dict with the same structure and elements. | ||||
|     However, all tensors in the returned state_dict are new tensors on CPU. These | ||||
|     tensors can be placed on pin_memory or share_memory based on the provided arguments. | ||||
|  | ||||
|     .. warning:: | ||||
|         Setting both `pin_memory` and `share_memory` to True significantly increases the | ||||
|         latency of this method because of the nuances which require us to register memory | ||||
|         as pinned directly as opposed to relying on the pin_memory cache allocator. This | ||||
|         option should only be used for long lived tensors which are required to be shared. | ||||
|         This is not the case as long as at least one of `pin_memory` or `share_memory` is | ||||
|          set to False. | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def tensor_func( | ||||
|         obj: torch.Tensor, | ||||
|         pg: Optional[dist.ProcessGroup], | ||||
|         device: Optional[torch.device], | ||||
|         _: Any, | ||||
|     ) -> torch.Tensor: | ||||
|         if len(obj.size()) == 0: | ||||
|             return torch.tensor(0, dtype=obj.dtype) | ||||
|  | ||||
|         if share_memory: | ||||
|             t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_() | ||||
|             if pin_memory: | ||||
|                 succ = torch.cuda.cudart().cudaHostRegister( | ||||
|                     t.data_ptr(), | ||||
|                     t.numel() * t.element_size(), | ||||
|                     1,  # lines up with 'cudaHostRegisterPortable' | ||||
|                 ) | ||||
|                 assert ( | ||||
|                     succ == 0 | ||||
|                 ), f"Pinning shared memory failed with error-code: {succ}" | ||||
|             return t | ||||
|         elif pin_memory: | ||||
|             return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() | ||||
|         else: | ||||
|             return torch.empty(*tuple(obj.size()), dtype=obj.dtype) | ||||
|  | ||||
|     ret = _iterate_state_dict( | ||||
|         state_dict, | ||||
|         _identity_func, | ||||
|         _identity_func, | ||||
|         tensor_func, | ||||
|         pg=None, | ||||
|         device=None, | ||||
|         cpu_offload=False, | ||||
|         ranks_only=tuple(), | ||||
|         type_check=False, | ||||
|     ) | ||||
|     return ret | ||||
|  | ||||
|  | ||||
| def _check_state_dict_similarity( | ||||
|     state_dict: Dict[str, Any], | ||||
|     compared_state_dict: Dict[str, Any], | ||||
| ) -> bool: | ||||
|     """ | ||||
|     Given two state_dicts, check if the structures are the same. And | ||||
|     if a [key, tensor] pair exist in one state_dict there must be | ||||
|     the a corresponding pait, [key, other_tensor], in the other state_dict, | ||||
|     where tensor and other_tensor have the same size and dtype. | ||||
|  | ||||
|     Return the check result. | ||||
|     """ | ||||
|  | ||||
|     def tensor_func( | ||||
|         obj: torch.Tensor, | ||||
|         pg: Optional[dist.ProcessGroup], | ||||
|         device: Optional[torch.device], | ||||
|         companion_obj: Any, | ||||
|     ) -> torch.Tensor: | ||||
|         if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): | ||||
|             raise CompanionMismatch() | ||||
|         return obj | ||||
|  | ||||
|     try: | ||||
|         _iterate_state_dict( | ||||
|             state_dict, | ||||
|             _identity_func, | ||||
|             _identity_func, | ||||
|             tensor_func, | ||||
|             pg=None, | ||||
|             device=None, | ||||
|             cpu_offload=False, | ||||
|             ranks_only=tuple(), | ||||
|             companion_obj=compared_state_dict, | ||||
|             type_check=False, | ||||
|         ) | ||||
|     except CompanionMismatch: | ||||
|         return False | ||||
|  | ||||
|     return True | ||||
| @ -1171,7 +1171,7 @@ def init_process_group( | ||||
|             ) | ||||
|  | ||||
|         default_pg, _ = _new_process_group_helper( | ||||
|             -1, -1, [], backend, None, group_name, timeout=timeout | ||||
|             -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg" | ||||
|         ) | ||||
|         _update_default_pg(default_pg) | ||||
|     else: | ||||
| @ -1197,6 +1197,7 @@ def init_process_group( | ||||
|             pg_options=pg_options, | ||||
|             timeout=timeout, | ||||
|             device_id=device_id, | ||||
|             group_desc="default_pg" | ||||
|         ) | ||||
|         _update_default_pg(default_pg) | ||||
|  | ||||
| @ -1257,6 +1258,7 @@ def _new_process_group_helper( | ||||
|     timeout=None, | ||||
|     pg_tag=None, | ||||
|     device_id=None, | ||||
|     group_desc=None, | ||||
| ): | ||||
|     """ | ||||
|     Create a new distributed process group. | ||||
| @ -1289,6 +1291,8 @@ def _new_process_group_helper( | ||||
|             _, prefix_store = _world.pg_map[existing_group] | ||||
|             return existing_group, prefix_store | ||||
|  | ||||
|     group_desc = "undefined" if group_desc is None else group_desc | ||||
|  | ||||
|     # The list of group ranks is empty if we're creating the default group. | ||||
|     is_default_group = len(global_ranks_in_group) == 0 | ||||
|  | ||||
| @ -1375,6 +1379,7 @@ def _new_process_group_helper( | ||||
|             if split_from: | ||||
|                 pg_options.split_from = split_from | ||||
|                 pg_options.split_color = _process_group_color(global_ranks_in_group) | ||||
|             pg_options.group_name = group_name | ||||
|             backend_class = ProcessGroupNCCL( | ||||
|                 backend_prefix_store, group_rank, group_size, pg_options) | ||||
|             backend_type = ProcessGroup.BackendType.NCCL | ||||
| @ -1461,9 +1466,11 @@ def _new_process_group_helper( | ||||
|  | ||||
|     # update global state | ||||
|     assert group_name is not None | ||||
|     assert group_desc is not None | ||||
|     _world.pg_map[pg] = (backend, prefix_store) | ||||
|     _world.pg_names[pg] = group_name | ||||
|     pg._set_group_name(group_name) | ||||
|     pg._set_group_desc(group_desc) | ||||
|  | ||||
|     _world.pg_backend_config[pg] = str(backend_config) | ||||
|     # "" is the default tag for user PGs | ||||
| @ -3614,7 +3621,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend: | ||||
|  | ||||
|  | ||||
| @_time_logger | ||||
| def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False): | ||||
| def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None): | ||||
|     """ | ||||
|     Create a new distributed group. | ||||
|  | ||||
| @ -3655,6 +3662,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local | ||||
|             barrier at the end of the process group creation. This is different | ||||
|             in that non-member ranks don't need to call into API and don't | ||||
|             join the barrier. | ||||
|         group_desc (str, optional): a string to describe the process group. | ||||
|  | ||||
|     Returns: | ||||
|         A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``. | ||||
| @ -3669,7 +3677,15 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local | ||||
|     multiple overlaping process groups. To avoid that, make sure all ranks follow the | ||||
|     same global creation order. | ||||
|     """ | ||||
|     return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization) | ||||
|     return _new_group_with_tag( | ||||
|         ranks, | ||||
|         timeout, | ||||
|         backend, | ||||
|         pg_options, | ||||
|         None, | ||||
|         use_local_synchronization=use_local_synchronization, | ||||
|         group_desc=group_desc, | ||||
|     ) | ||||
|  | ||||
| def _new_group_with_tag( | ||||
|     ranks=None, | ||||
| @ -3677,7 +3693,8 @@ def _new_group_with_tag( | ||||
|     backend=None, | ||||
|     pg_options=None, | ||||
|     pg_tag=None, | ||||
|     use_local_synchronization=False | ||||
|     use_local_synchronization=False, | ||||
|     group_desc=None | ||||
| ): | ||||
|     """ | ||||
|     Variant of ``new_group`` that exposes tag creation. | ||||
| @ -3749,7 +3766,8 @@ def _new_group_with_tag( | ||||
|         group_name, | ||||
|         pg_options=pg_options, | ||||
|         timeout=timeout, | ||||
|         pg_tag=pg_tag | ||||
|         pg_tag=pg_tag, | ||||
|         group_desc=group_desc | ||||
|     ) | ||||
|  | ||||
|     # Create the global rank to group rank mapping | ||||
| @ -3789,6 +3807,7 @@ def new_subgroups( | ||||
|     timeout=None, | ||||
|     backend=None, | ||||
|     pg_options=None, | ||||
|     group_desc=None, | ||||
| ): | ||||
|     """ | ||||
|     Create subgroups of equal size. | ||||
| @ -3841,6 +3860,8 @@ def new_subgroups( | ||||
|             the construction of specific process groups. i.e. for the ``nccl`` | ||||
|             backend, ``is_high_priority_stream`` can be specified so that | ||||
|             process group can pick up high priority cuda streams. | ||||
|         group_desc (str, optional): A string describing the group. Each subgroup will | ||||
|             inherit its group_desc | ||||
|  | ||||
|     Returns: | ||||
|         The subgroup containing the current rank, and all the subgroups used for cleanup. | ||||
| @ -3886,6 +3907,7 @@ def new_subgroups( | ||||
|             timeout=timeout, | ||||
|             backend=backend, | ||||
|             pg_options=pg_options, | ||||
|             group_desc=group_desc, | ||||
|         ) | ||||
|         subgroups.append(subgroup) | ||||
|  | ||||
| @ -3905,6 +3927,7 @@ def new_subgroups_by_enumeration( | ||||
|     timeout=None, | ||||
|     backend=None, | ||||
|     pg_options=None, | ||||
|     group_desc=None, | ||||
| ): | ||||
|     """ | ||||
|     Create subgroups by dividing the global world. | ||||
| @ -3945,6 +3968,8 @@ def new_subgroups_by_enumeration( | ||||
|             the construction of specific process groups. i.e. for the ``nccl`` | ||||
|             backend, ``is_high_priority_stream`` can be specified so that | ||||
|             process group can pick up high priority cuda streams. | ||||
|         group_desc (str, optional): A string describing the group. Each subgroup will | ||||
|             inherit its group_desc. | ||||
|  | ||||
|     Returns: | ||||
|         The subgroup containing the current rank, and all the subgroups used for cleanup. | ||||
| @ -3973,6 +3998,7 @@ def new_subgroups_by_enumeration( | ||||
|             timeout=timeout, | ||||
|             backend=backend, | ||||
|             pg_options=pg_options, | ||||
|             group_desc=group_desc, | ||||
|         ) | ||||
|         subgroups.append(subgroup) | ||||
|         my_rank = get_rank() | ||||
|  | ||||
| @ -28,7 +28,7 @@ def tail_logfile( | ||||
|             return | ||||
|         time.sleep(interval_sec) | ||||
|  | ||||
|     with open(file) as fp: | ||||
|     with open(file, errors="replace") as fp: | ||||
|         while True: | ||||
|             line = fp.readline() | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	