Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits: malfet-pat...flight_5.1 (56 commits)
| SHA1 |
|---|
| b86edd97d6 |
| b33a283e9a |
| 7a551d81e5 |
| 1515a90475 |
| 4882ec2a91 |
| 972b8060bd |
| 3e7683ae18 |
| f2e9ec2dc5 |
| dde4324d8e |
| 94c079104d |
| a6afee6d94 |
| d092857531 |
| 6aad5e444a |
| c54ce9313b |
| 1fe59f4ef7 |
| e693fb2bb1 |
| 4fe510baf6 |
| 7c507b78c4 |
| 0019901601 |
| 18be18535b |
| 2729367313 |
| 33537aae24 |
| dcdb1337dd |
| 9cf0f2bd59 |
| 1d2e877c05 |
| f27b979b0c |
| f30d6047ad |
| 75311510ef |
| dbd6094d05 |
| 397b9d47e9 |
| 36a01a8ab9 |
| ee336cf58a |
| a9e2e745d7 |
| ab4df89eea |
| 9d02ebe876 |
| b61e01cce9 |
| f7ce61ba53 |
| e7bae15ab1 |
| e71b422908 |
| 389940ce60 |
| b2237a7c85 |
| ef5dfe3f3e |
| e303dc3c08 |
| 265efad2de |
| 60f0455905 |
| 4898313791 |
| f4da9adf6b |
| 8f7f35273e |
| 44ec9612ed |
| 4d3bea2b29 |
| 0bcdddc3c1 |
| 28b6220312 |
| 210b7b65e2 |
| 4da10b5cd3 |
| f09763814f |
| 80923ed5a6 |
@@ -1732,7 +1732,7 @@ if(BUILD_TEST)
  foreach(test_src ${Caffe2_CPU_TEST_SRCS})
    get_filename_component(test_name ${test_src} NAME_WE)
    add_executable(${test_name} "${test_src}")
-    target_link_libraries(${test_name} torch_library gtest_main)
+    target_link_libraries(${test_name} torch_library gtest_main stdc++)
    target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
    target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
    target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
@@ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) {
// Extend the nested class outside the parent class
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
-  TestDebugInfoWriter() : DebugInfoWriter(0) {}
+  TestDebugInfoWriter(std::string namePrefix)
+      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
@@ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  // The storer here is very similar to the fallback storer.
  // The only difference is that we are storing traces also in memory for
  // validation.
+  std::string fileNamePrefix = c10d::getCvarString(
+      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
-      std::make_unique<TestDebugInfoWriter>();
+      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();
-  pg.registerDebugInfoWriter(std::move(wrterForTestPtr));
+  c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
@@ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
+    // TODO (kwen2501)
+    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
+                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
@@ -4,9 +4,10 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
-from torch.distributed.checkpoint._state_dict_utils import (
+from torch.distributed._state_dict_utils import (
    _check_state_dict_similarity,
    _copy_state_dict,
    _create_cpu_state_dict,
    _gather_state_dict,
    _offload_state_dict_to_cpu,
)
@@ -115,6 +116,58 @@ class TestStateDictUtils(DTensorTestBase):
        }
        self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))

    @skip_if_lt_x_gpu(2)
    def test_create_cpu_state_dict(self):
        device = torch.device("cuda")
        buffer = io.BytesIO()
        torch.save(torch.ones(10), buffer)
        buffer.seek(0)
        state_dict = {
            "tensor1": torch.arange(10, device=device),
            "tensor2": torch.ones(10, device=device),
            "non_tensor_bytes_io": copy.deepcopy(buffer),
            "non_tensor_bytes": buffer.read(),
            "step": torch.tensor(7, dtype=torch.float),
            "lr": 1.5,
            "nested": {"list": [1, 2, 3, 4]},
        }

        def _verify(cpu_state_dict):
            # Verify the correctness of _check_state_dict_similarity()
            self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict))
            tensor1 = cpu_state_dict["tensor1"]
            cpu_state_dict["tensor1"] = torch.arange(11)
            self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict))
            cpu_state_dict["tensor1"] = tensor1

            _copy_state_dict(state_dict, cpu_state_dict)

            # Verify if _copy_state_dict works
            for v in cpu_state_dict.values():
                if isinstance(v, torch.Tensor):
                    self.assertFalse(v.is_cuda)
            self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
            self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
            buffer.seek(0)
            cpu_state_dict["non_tensor_bytes_io"].seek(0)
            self.assertEqual(
                cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read()
            )
            buffer.seek(0)
            self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read())
            self.assertEqual(cpu_state_dict["lr"], 1.5)
            self.assertEqual(cpu_state_dict["step"], 7)
            self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})

        cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(
            state_dict, share_memory=True, pin_memory=True
        )
        _verify(cpu_state_dict)


if __name__ == "__main__":
    run_tests()
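For orientation, here is a minimal sketch of how the helpers exercised by this test can be used on their own. It assumes the `torch.distributed._state_dict_utils` import path introduced above and an available CUDA device; the tensor names are made up for the example.

```python
import torch
from torch.distributed._state_dict_utils import (
    _check_state_dict_similarity,
    _copy_state_dict,
    _create_cpu_state_dict,
)

# A GPU-resident state dict (hypothetical contents).
state_dict = {
    "weight": torch.randn(4, 4, device="cuda"),
    "step": torch.tensor(3, dtype=torch.float),
    "lr": 0.1,
}

# Allocate a CPU copy with the same structure (optionally pinned), check that
# the two dicts have matching structure, then copy the data over.
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
assert _check_state_dict_similarity(state_dict, cpu_state_dict)
_copy_state_dict(state_dict, cpu_state_dict)
assert not cpu_state_dict["weight"].is_cuda
```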
@@ -11,6 +11,7 @@ import tempfile
import threading
import pickle
import time
import json
import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
@@ -1334,6 +1335,19 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
        self.assertEqual(tensor, original_tensor)


    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_set_process_group_desc(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f'cuda:{self.rank}')
        pg_default = self._create_process_group_nccl(store, self.opts(), device_id=device)
        self.assertEqual(pg_default.group_desc, "default_pg")
        pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
        self.assertEqual(pg_1.group_desc, "test_purpose")
        pg_2 = c10d.new_group([0, 1])
        self.assertEqual(pg_2.group_desc, "undefined")


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
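As a usage note for the `group_desc` plumbing tested above, a small sketch (assuming `torch.distributed.init_process_group` has already been called on at least two ranks; the description string is illustrative):

```python
import torch.distributed as dist

# Label a subgroup with a human-readable purpose; the description is surfaced
# through the new group_desc property.
pg = dist.new_group([0, 1], group_desc="checkpoint_io")
print(pg.group_desc)            # "checkpoint_io"

pg_unlabeled = dist.new_group([0, 1])
print(pg_unlabeled.group_desc)  # "undefined" when no description is provided
```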
@@ -3542,11 +3556,12 @@ class SparseCollective(MultiProcessTestCase):
class NCCLTraceTestBase(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
-        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'
+        os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'  # see 'timing_enabled' parametrized tests
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
        self.tempdir = tempfile.TemporaryDirectory()
        os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
        os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
        self._spawn_processes()

    @classmethod
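For reference, the same knobs this test harness sets can be exported in a training script; a minimal sketch with illustrative values (they must be set before the NCCL process group is created, and the rank is appended to the file prefix when the dump is written):

```python
import os

# Ring-buffer size of the NCCL flight recorder (0 disables recording).
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
# Dump the recorded entries when the watchdog detects a timeout.
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
# Per-rank file prefix for the dump file and the optional trigger pipe.
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/nccl_trace_rank_"
os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = "/tmp/nccl_trace_rank_"
```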
@@ -3617,28 +3632,50 @@ class NCCLTraceTest(NCCLTraceTestBase):

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
-    def test_short(self):
+    @parametrize("timing_enabled", [True, False])
+    def test_short(self, timing_enabled):
        if self.rank == self.MAIN_PROCESS_RANK:
            return
        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        a = torch.full((3, 4), float(self.rank), device=device)
        for i in range(2):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)

        # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api
        time.sleep(1)

        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        ver = t['version']
        self.assertEqual(ver, "1.1")
        t = t['entries']
        self.assertEqual(len(t), 2)
        last = t[-1]
        self.assertEqual(last['process_group'], ('0', 'default_pg'))
        self.assertEqual(last['state'], 'completed')
        s = last['time_discovered_started_ns']
        f = last['time_discovered_completed_ns']
        self.assertIsNotNone(f)
        if timing_enabled:
            self.assertIsNotNone(s)
            self.assertTrue(s <= f)
        self.assertIn('test_c10d_nccl.py', str(last['frames']))
        self.assertEqual(last['input_sizes'], ((3, 4),))
        self.assertEqual(last['output_sizes'], ((3, 4),))
        self.assertEqual(last['seq_id'], 2)
        now = datetime.now()
-        event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000)
+        event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000)
        before_test = now - timedelta(minutes=1)
        self.assertTrue(before_test < event_created_time < now)
        if timing_enabled:
            # very loose bounds, measured 0.036 ms on devgpu
            self.assertTrue(0 < last['duration_ms'] < 100)
        else:
            self.assertTrue("duration_ms" not in last)
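For illustration, a sketch of inspecting the flight recorder the way this test does, outside the test harness. `_dump_nccl_trace` is a private binding and may change; the sketch assumes a ProcessGroupNCCL has already executed at least one collective:

```python
import pickle

import torch

# The dump is a pickled dict with a 'version' string and an 'entries' list,
# one entry per recorded collective.
trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
print(trace["version"])
for entry in trace["entries"]:
    print(entry["profiling_name"], entry["state"], entry["input_sizes"])
    # 'duration_ms' only appears when collective timing is enabled.
    if "duration_ms" in entry:
        print("  took", entry["duration_ms"], "ms")
```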
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -3698,9 +3735,11 @@ class NCCLTraceTest(NCCLTraceTestBase):
        f.wait()
        torch.cuda.synchronize(device=device)
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        t = t['entries']
        self.assertEqual(len(t), 10)
        first = t[0]
        last = t[-1]
        self.assertEqual(last['profiling_name'], 'nccl:all_reduce')
        self.assertEqual(last['state'], 'completed')
        self.assertIn('test_c10d_nccl.py', str(last['frames']))
        self.assertEqual(last['input_sizes'], ((3, 4),))
@@ -3732,6 +3771,8 @@ class NCCLTraceTest(NCCLTraceTestBase):
                pg.allreduce(a).wait()
            e.synchronize()
            t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
            t = t['entries']
            self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
            if self.rank == 0:
                self.assertEqual(t[-1]['seq_id'], 1)
                self.assertEqual(t[-1]['state'], 'completed')
@@ -3773,12 +3814,15 @@ class NCCLTraceTest(NCCLTraceTestBase):
                # give the other thread some time to fill the cuda buffer
                time.sleep(5)
                t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
                t = t['entries']
                self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
                if self.rank == 0:
                    self.assertEqual(t[-1]['seq_id'], 1)
                    self.assertEqual(t[-1]['state'], 'completed')
                else:
                    self.assertEqual(t[-1]['seq_id'], 2)
                    self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled))
                    self.assertIsNone(t[-1]['time_discovered_completed_ns'])
                # this will eventually cause the missing rank 0
                # to continue which will unblock the non-zero ranks
                self.parent.send('next')
@@ -3832,6 +3876,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_timeout_dumps(self, timing_enabled):
        # dump on heartbeatmonitor thread
        os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
        # need rank0 to crash before looking for its output file
        os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for rank0 to crash before looking for its output file
@@ -3839,6 +3887,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
            self.assertEqual(self._wait_process(0, timeout=90), -6)
            with open(self._trace_name(rank=0), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 2)
                self.assertEqual(t[0]['seq_id'], 1)
                self.assertEqual(t[0]['state'], 'completed')
@@ -3868,7 +3917,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
instantiate_parametrized_tests(NCCLTraceTest)

-class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
+class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
    def _check_return_codes(self, elapsed_time):
        # the base test infra assumes processes exit with matching return codes,
        # but we want rank0 to abort and rank1 to exit cleanly in this test
@@ -3877,7 +3926,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
-    def test_timeout_dumps_on_idle_ranks(self):
+    def test_timeout_dumps_on_stuck_ranks(self):

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for both rank0 and 1 to crash before looking for both ranks' output
@@ -3888,20 +3937,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
            with open(self._trace_name(rank=0), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 2)
            with open(self._trace_name(rank=1), 'rb') as f:
                t = pickle.load(f)
                t = t['entries']
                self.assertEqual(len(t), 1)
                self.assertEqual(t[0]['seq_id'], 1)
                self.assertEqual(t[0]['state'], 'completed')
            return

        # Set heartbeat timeout to a shorter one (default timeout is 2 min).
        os.environ[
            "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"
        ] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}"
        pg = self._create_process_group_nccl()

        device = self.local_device
        with torch.cuda.device(device):
            a = torch.full((3, 4), float(self.rank), device=device)
@@ -3910,12 +3956,68 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
            if self.rank == 0:
                pg.allreduce(a).wait()

-            # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
+            # rank 0 will get stuck, timeout and then signal a timeout to all ranks.
            torch.cuda.synchronize()

-            # Force rank 1 to idle so that it also gets debug info dump triggered.
            if self.rank == 1:
-                time.sleep(6)
+                # Force rank 1 to idle so that it will eventually timeout as well after
+                # getting the global signal to dump the debugging info.
+                time.sleep(600)

class NcclErrorDumpTest(NCCLTraceTestBase):
    def _wait_process(self, rank, timeout):
        try:
            self.processes[rank].join(timeout)
            return self.processes[rank].exitcode
        except TimeoutError:
            return None

    def _check_return_codes(self, elapsed_time):
        # the base test infra assumes processes exit with matching return codes,
        # but we want rank0 to abort with exception and rank1 to exit with exit 1
        self.assertEqual(self.processes[0].exitcode, -6)
        self.assertEqual(self.processes[1].exitcode, 1)

    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_nccl_errors_dump(self):
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
        # need rank0 to dump before abort
        os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5'

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for both rank0 and 1 to crash before looking for dump
            self.assertEqual(self._wait_process(0, timeout=90), -6)
            self.assertEqual(self._wait_process(1, timeout=90), 1)
            # verify that the trace file exists for rank0
            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
            return

        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(
            store,
            self.rank,
            self.world_size,
            timeout=timedelta(seconds=10),
        )
        process_group.allreduce(torch.rand(10).cuda(self.rank))
        if self.rank == 0:
            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
            # expect an error to be raised
            with self.assertRaisesRegex(dist.DistBackendError, ""):
                # Block the current stream on the NCCL stream
                work.wait()
                # Run some GPU operations
                a = torch.rand(10).cuda(self.rank)
        elif self.rank == 1:
            # Clean up structures (ex: files for FileStore before going down)
            del process_group
            sys.exit(1)


if __name__ == "__main__":
    assert (
@@ -71,7 +71,7 @@ class StoreTestBase:
    def _create_store(self, i):
        raise RuntimeError("not implemented")

-    def _test_set_get(self, fs):
+    def _test_set_get_check(self, fs):
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
@@ -90,14 +90,16 @@ class StoreTestBase:
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

-    def test_set_get(self):
-        self._test_set_get(self._create_store())
+    def test_set_get_check(self):
+        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
@@ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        store = self._create_store()
        wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)

class MyPythonStore(dist.Store):
    def __init__(self):
        super().__init__()
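To show the new accessor outside the TCPStore-based test above, a self-contained sketch using a `HashStore` (the `_underlying_non_prefix_store` property is the Python binding for `getUnderlyingNonPrefixStore` added in this change; the store choice is only for the example):

```python
import torch.distributed as dist

base = dist.HashStore()
wrapped = dist.PrefixStore("outer/", dist.PrefixStore("inner/", base))

# The accessor unwraps every PrefixStore layer and returns the innermost,
# non-prefixed store, so writes through it are visible via `base`.
inner = wrapped._underlying_non_prefix_store
inner.set("key", "value")
print(base.get("key"))  # b'value'
```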
@@ -463,6 +463,7 @@ class ProcessGroup:
        backend: Optional[ProcessGroup],
    ) -> None: ...
    def _set_group_name(self, name: str) -> None: ...
    def _set_group_desc(self, desc: str) -> None: ...
    def name(self) -> str: ...
    def _has_hooks(self) -> bool: ...
    def _wait_for_pending_works(self) -> None: ...
@@ -471,6 +472,10 @@ class ProcessGroup:
    def bound_device_id(self) -> Optional[torch.device]: ...
    @bound_device_id.setter
    def bound_device_id(self, device: Optional[torch.device]) -> None: ...
    @property
    def group_name(self) -> str: ...
    @property
    def group_desc(self) -> str: ...

class ProcessGroupRoundRobin(ProcessGroup): ...
@@ -369,6 +369,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
    return pg_name_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  c10::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
@@ -399,6 +407,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
  // remains the same across use of this process group.
  DebugLevel dist_debug_level_;
  std::string pg_name_;
  std::string pg_desc_;

  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;
@@ -8,6 +8,7 @@
#include <memory>
#include <mutex>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <nccl.h>
@@ -175,14 +176,29 @@ std::string getNcclErrorDetailStr(
    c10::optional<std::string> processGroupFailureReason = c10::nullopt);

// Write NCCL debug info to local disk or any storage users define.
// There are some constrains we set for the debug info writer:
// 1. The writer should only be registered once.
// 2. Once registered, users cannot change it including un-register.
// 3. It is recommended to register the customized writer in the trainer setup,
//    If users don't register before calling launchAsyncDebugDump, then users
//    lose the chance to register (and the default writer will be
//    auto-registered).
class TORCH_API DebugInfoWriter {
 public:
  DebugInfoWriter(int rank);
  virtual ~DebugInfoWriter();
  virtual void write(const std::string& ncclTrace);
  static DebugInfoWriter& getWriter(int rank);
  static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);

 protected:
  DebugInfoWriter(std::string namePrefix, int rank) {
    filename_ = c10::str(namePrefix, rank);
  }
  std::string filename_;

 private:
  static std::unique_ptr<DebugInfoWriter> writer_;
  static std::atomic<bool> hasWriterRegistered_;
};

// RAII wrapper for NCCL communicator
@@ -267,6 +283,18 @@ class NCCLComm {
  }
#endif

#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
  std::unordered_map<std::string, std::string> ncclCommDump() {
    std::unordered_map<std::string, std::string> dump;
    if (isAborted()) {
      LOG(INFO) << "Communicator was aborted before trying to dump its state.";
      return dump;
    }
    C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt);
    return dump;
  }
#endif

  ncclUniqueId getNcclId() {
    return ncclId_;
  }
@@ -322,6 +350,9 @@ class NCCLComm {
    // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
    // timeout)
    commFailureReason_ = commFailureReason;
    LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
              << (commFailureReason ? *commFailureReason
                                    : "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
@@ -421,6 +452,8 @@ class NCCLComm {
#endif
  }

  friend class ProcessGroupNCCL;

 protected:
  ncclComm_t ncclComm_;
  // Unique nccl_id for this communicator.
@@ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
  return store_;
}

c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
  c10::intrusive_ptr<Store> store = store_;

  while (store) {
    // Attempt to dynamically cast to PrefixStore
    PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
    if (asPrefixStore) {
      store = asPrefixStore->getUnderlyingStore();
    } else {
      break; // We've reached a non-PrefixStore
    }
  }

  TORCH_CHECK(
      store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
  return store;
}

} // namespace c10d
@@ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store {

  c10::intrusive_ptr<Store> getUnderlyingStore();

  // Recursively to fetch the store before layers of wrapping with PrefixStore.
  c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();

 protected:
  std::string prefix_;
  c10::intrusive_ptr<Store> store_;
@@ -165,6 +165,18 @@ void ProcessGroup::setGroupName(const std::string& name) {
  }
}

const std::string& ProcessGroup::getGroupDesc() const {
  return pg_desc_;
}

void ProcessGroup::setGroupDesc(const std::string& name) {
  pg_desc_ = name;
  // Also set the group desc for all backends
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->setGroupDesc(name);
  }
}

void ProcessGroup::enableCollectivesTiming() {
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->enableCollectivesTiming();
@@ -694,6 +694,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {

  const std::string& getGroupName() const;
  void setGroupName(const std::string& name);
  const std::string& getGroupDesc() const;
  void setGroupDesc(const std::string& name);
  void enableCollectivesTiming();

  void release_resources() override;
@@ -724,6 +726,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  const int size_;
  const c10::intrusive_ptr<Options> options_;
  const BackendType backendType_;
  std::string pg_desc_;

  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.

(One file's diff is suppressed because it is too large.)
@@ -1,7 +1,15 @@
#pragma once

#if defined(__linux__)
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif

#ifdef USE_C10D_NCCL

#include <atomic>
#include <chrono>
#include <future>
#include <iostream>
@@ -12,6 +20,7 @@

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <ATen/DynamicLibrary.h>
@@ -26,52 +35,74 @@
#include <torch/custom_class.h>

namespace c10d {
-// Environment variable which controls whether we perform a NCCL healt check
+// Control whether we perform a NCCL health check or not
// which ensures communicators are healthy at the beginning of init.
static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = {
    "TORCH_ENABLE_NCCL_HEALTH_CHECK",
    "ENABLE_NCCL_HEALTH_CHECK"};

-// Environment variable which controls whether or not wait() is blocking or
-// non-blocking.
+// Control whether or not wait() is blocking or non-blocking.
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
    "TORCH_NCCL_BLOCKING_WAIT",
    "NCCL_BLOCKING_WAIT"};

-// Environment variable which controls whether or not we perform Async Error
-// Handling with NCCL.
+// Control whether or not we perform Async Error Handling with NCCL.
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
    "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    "NCCL_ASYNC_ERROR_HANDLING"};

-// Environment Variable to control whether dumping debug info on watchdog
+// Control whether dumping debug info on watchdog
// timeout is enabled. This variable must be set together with
// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
    "TORCH_NCCL_DUMP_ON_TIMEOUT"};

-// Environment Variable to control whether Desync Debug is enabled.
-// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
+// Control whether Desync Debug is enabled. This variable must be set
+// together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
    "TORCH_NCCL_DESYNC_DEBUG",
    "NCCL_DESYNC_DEBUG"};

// Enable recording start-events for all ProcessGroupNCCL collectives, and
// compute accurate collective timing per-collective. (Note: end-events are
// recorded by default. Turn on this flag can increase chances of a watchdog
// hang due to performing a CUDA event query which eventually calls
// cudaEventElapsedTime() API.
static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
    "TORCH_NCCL_ENABLE_TIMING",
    "NCCL_ENABLE_TIMING"};

// Enable monitoring thread which aborts the process when the ProcessGroupNCCL
// Watchdog thread gets stuck and no heartbeat is detected after
// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged
// time than necessary tying up cluster resources.
static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
    "TORCH_NCCL_ENABLE_MONITORING"};

// Control the watchdog heartbeat timeout period after which the monitoring
// thread will abort the process.
static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
    "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};

// The maximum number of events we store in the flight recorder's ring buffer.
// (One event could be the start or end of a collective, for example).
static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
    "TORCH_NCCL_TRACE_BUFFER_SIZE"};

// Control how much extra time we will wait for dumping the debugging info
// before we exit and throws timeout exception.
static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
    "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};

// Control the interval inside the watchdog thread to check the coordinated
// signal from other ranks, e.g. to dump the debugging information.
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
    "TORCH_NCCL_COORD_CHECK_MILSEC"};

constexpr const char* NCCL_BACKEND_NAME = "nccl";

-constexpr const char* TIMEOUT_DUMP = "timeout_dump";
+constexpr const char* EXCEPTION_DUMP = "exception_dump";

constexpr auto kProcessGroupNCCLDefaultTimeout =
    std::chrono::milliseconds(10 * 60 * 1000);
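As a companion to the variable lists above, a short sketch of opting a job into the heartbeat monitor and timeout dumps; the values are illustrative, and per the comments these settings are read when the ProcessGroupNCCL is constructed:

```python
import os

# Abort the process if the NCCL watchdog stops heartbeating for too long.
os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "1"
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "600"
# Dump flight-recorder traces on timeout; per the comment above this also
# requires monitoring to be enabled and a non-zero trace buffer.
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
# How often (ms) the watchdog checks for a coordinated dump signal from peers.
os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"
```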
@@ -110,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
    {"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
     "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};

#if defined(__linux__)
struct DumpPipe {
  DumpPipe(int rank) {
    std::string fileStem =
        getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
    if (fileStem.empty() ||
        getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
      return;
    }
    TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
    std::string filename = c10::str(fileStem, rank, ".pipe");
    TORCH_CHECK(
        unlink(filename.c_str()) != -1 || errno == ENOENT,
        "Error removing existing named pipe ",
        filename);
    TORCH_CHECK(
        mkfifo(filename.c_str(), 0666) != -1,
        "Error creating named pipe ",
        filename);
    fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
    LOG(INFO) << "Pipe file " << filename
              << " has been opened, write to it to trigger NCCL Debug Dump.";
    TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
  }
  bool shouldDump() {
    if (fd_ == -1) {
      return false;
    }
    char buf[128];
    // non-blocking from O_NONBLOCK above.
    // Ignore EINTR because we already will poll this
    // again later.
    ssize_t bytesRead = read(fd_, &buf, 128);
    return bytesRead > 0;
  }
  ~DumpPipe() {
    if (fd_ != -1) {
      close(fd_);
    }
  }

 private:
  int fd_ = -1;
};
#else
struct DumpPipe {
  DumpPipe(int rank) {}
  bool shouldDump() {
    return false;
  }
};
#endif

// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
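For illustration, a sketch of triggering a dump through this pipe from a separate process. It assumes TORCH_NCCL_DEBUG_INFO_PIPE_FILE was set to `/tmp/nccl_trace_rank_`, so the constructor above created `/tmp/nccl_trace_rank_<rank>.pipe`, and that the watchdog is polling `shouldDump()`:

```python
import os

rank = 0  # rank whose debug info should be dumped
pipe_path = f"/tmp/nccl_trace_rank_{rank}.pipe"

# Any write makes the non-blocking read() in shouldDump() return > 0,
# which triggers the NCCL debug dump for that rank.
fd = os.open(pipe_path, os.O_WRONLY | os.O_NONBLOCK)
os.write(fd, b"1")
os.close(fd)
```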
@@ -205,6 +289,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

    uint64_t getSequencenumber() const override;

    const std::string& logPrefix() const;

    // Helper function that sets an exception_ptr on the WorkNCCL object.
    void setException(std::exception_ptr exception_ptr);
@@ -358,6 +444,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // via `ncclCommSplit`
    std::shared_ptr<ProcessGroupNCCL> split_from;
    int64_t split_color{0};
    std::string group_name;
  };

  // If you wish to create multiple process groups, each with a potentially
@@ -540,8 +627,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  void enableCollectivesTiming() override;

-  // Provide an API for users to define their own ways to store NCCL debug info.
-  void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
+  // Helper function for iteratively aborting communicators in the provided map
+  void abortCommsFromMap(
+      std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
+          ncclCommsMap,
+      c10::optional<std::string> abortReason);

  // Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
  // instead of relying on ProcessGroupNCCL destructor.
@@ -694,6 +784,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // Desync debug helper
  void logWorkEnd(WorkNCCL& work);

  // Generates a prefix that is unique to this process group and rank, for
  // disambiguating logs
  std::string createLogPrefix() const;

  // Returns the unique prefix created in createLogPrefix
  const std::string& logPrefix() const;

  // Returns the global rank of the device. This function assumes that users
  // always create a default global process group(PG) which includes all
  // devices. It is called in the constructor of ProcessGroupNCCL, so it always
  // return the rank_ of the the very first PG created, aka, default global PG.
  const int& globalRank() const;

 protected:
  // Function that runs as part of a separate thread aside from watchdog
  // thread because we need to check the heartbeat from watchdog thread
@@ -712,6 +815,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // for dump completion.
  std::future<bool> launchAsyncDebugDump();

  // Helper to wait up to the specified timeout and then abandon the dump.
  // Logs on timeout, and asserts the future's status is as expected.
  void waitForDumpOrTimeout(
      std::future<bool>& fut,
      const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
      size_t timeout_sec = 30);

  // A helper function to wait for a future to complete or timeout.
  void waitForFutureOrTimeout(
      std::future<bool>& fut,
      const std::chrono::milliseconds& timeOutMilSec,
      const std::string& futDescription);

  // When watchdog timeout, this function will be called and return debug info
  // for users. For now we only get information from retrieveDesyncReport.
  // We are working on enabling more useful debug information for watchdog
@@ -720,9 +836,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  static const int64_t kWatchdogThreadSleepMillis;

-  // The store is used to broadcast the NCCL unique ID of rank 0.
+  // The store is used to broadcast the NCCL unique ID of rank 0. This store
+  // comes with prefix and it is different across ProcessGroup NCCL instances
+  // (aka, different ProcessGroups).
  c10::intrusive_ptr<Store> store_;

  // Reference to the store without prefix so that keys are same across all
  // ProcessGroup NCCL instances and (key, value) pairs written to the store are
  // global.
  c10::intrusive_ptr<Store> globalStore_;

  bool storeError_{false};

  const c10::intrusive_ptr<Options> options_;
@@ -781,11 +904,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  std::mutex mutex_;

  // Heartbeat of watchdog thread.
-  uint64_t heartbeat_;
+  std::atomic_uint64_t heartbeat_;

  // The time interval used for deciding whether there is no watchdog heartbeat.
  int heartbeatTimeoutInSec_;

  // Extra time of sleep when waiting for timeout dump to finish.
  int waitTimeoutDumpInMilSec_;

  // Interval of check coordinated signals in ProcessGroupNCCL from other ranks
  // e.g., trigger the dump of the debugging info for timeout when notified.
  int coordCheckIntervalMilSec_;

  // Size of ring buffer where we store NCCL Traces for debugging.
  int ncclTraceBufferSize_;
@@ -815,6 +945,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // Whether there are hooks pending to be fired
  std::atomic<bool> hasPendingHooks_;

  // This is the signal from watchdog threads to indicate whether the monitor
  // thread should dump. Making it static so that it is accessiable from all the
  // PGs. With this flag, monitor thread would dump debug info under any one of
  // the 3 conditions: 1: this flag is set to true by the watchdog thread when
  // it detects a timeout. 2: timeout signal is received from
  // other ranks through tcpstore 3: no heartbeat of watchdog Note that only the
  // monitor thread from PG0 should dump the debug info and only once
  static std::atomic<bool> shouldDump_;

  // Mutex to Guard workMetaList_
  std::mutex workMetaListMutex_;
@@ -823,9 +962,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  bool writeDebugInfo_ = false;

  // Mutex to Guard the check of writeDebugInfo_
  std::mutex writeDebugInfoMutex_;

  // Condition Variable for watchdog thread sleep
  std::condition_variable workMetaListCV_;
@@ -902,8 +1038,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // Whether or not to enable timeout root cause analysis.
  bool desyncDebug_;

-  // Whether or not to dump debug info on timeout
-  bool dumpOnTimeout_;
+  // Whether or not to dump debug info on exception including both watchdog
+  // timeout and nccl errors.
+  bool dumpOnException_;

  // Whether or not to create start CUDAEvent and enable timing for start
  // and end events. Note that enableTiming_ is always true if desyncDebug_
@@ -929,14 +1066,25 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  std::exception_ptr watchDogException_ = nullptr;

  // The callback function to store NCCL debug info.
  std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;

  size_t uid_;

  std::string logPrefix_;
};

TORCH_API std::string dump_nccl_trace();

// Gets a mutable reference to a global optional function.  Heartbeat Monitor
// will query this function and if available, call it to dump traces. Inside
// fbcode, we store a function here that uses an internal tool for process
// tracing
TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper();

// Similar to get_cpp_trace_dumper, this stores a function defined in
// torch-python layer that lets us check whether the GIL can be acquired,
// helpful for instrumenting in cases where a hang was observed.
typedef bool (*gil_checker_t)();

TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d

#endif // USE_C10D_NCCL
@@ -13,7 +13,6 @@
#include <string>
#include <system_error>
#include <vector>

namespace c10d {

/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */
@@ -269,10 +268,20 @@ inline std::string retrieveDesyncReport(

#ifdef USE_C10D_NCCL

-DebugInfoWriter::DebugInfoWriter(int rank) {
-  std::string fileName = getCvarString(
-      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
-  filename_ = c10::str(fileName, rank);
+/* Helper used by work::getDuration() and nccl flight recorder */
+float getDurationFromFirstEvent(
+    const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
+    const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
+  TORCH_CHECK(
+      ncclStartEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple start events.");
+  TORCH_CHECK(
+      ncclEndEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple end events.");
+  TORCH_CHECK(
+      ncclEndEvents[0].query(),
+      "getDuration can only be called after work is succeeded.")
+  return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
}

DebugInfoWriter::~DebugInfoWriter() = default;
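The helper above reports the elapsed time between the recorded start and end CUDA events once the end event has completed. For readers who want to see the same measurement idea with public APIs, a rough Python analogue using `torch.cuda.Event` (not the internal helper itself):

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
torch.mm(torch.randn(1024, 1024, device="cuda"),
         torch.randn(1024, 1024, device="cuda"))
end.record()

# As in the C++ helper, elapsed_time() is only meaningful once the end event
# has actually completed (query() returns True).
torch.cuda.synchronize()
assert end.query()
print("duration_ms:", start.elapsed_time(end))
```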
@@ -293,6 +302,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) {
  LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
}

DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
  if (writer_ == nullptr) {
    std::string fileNamePrefix = getCvarString(
        {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
    // Using std::unique_ptr here to auto-delete the writer object
    // when the pointer itself is destroyed.
    std::unique_ptr<DebugInfoWriter> writerPtr(
        new DebugInfoWriter(fileNamePrefix, rank));
    DebugInfoWriter::registerWriter(std::move(writerPtr));
  }
  return *writer_;
}

void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
  TORCH_CHECK_WITH(
      DistBackendError,
      hasWriterRegistered_.load() == false,
      "debugInfoWriter already registered");
  hasWriterRegistered_.store(true);
  writer_ = std::move(writer);
}

std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);

inline std::string pickle_str(const c10::IValue& v) {
  std::vector<char> result;
  {
@ -317,6 +351,18 @@ inline c10::List<c10::IValue> new_list() {
 | 
			
		||||
  return c10::List<c10::IValue>(c10::AnyType::get());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline std::string ranks_str(const std::vector<uint64_t>& ranks) {
 | 
			
		||||
  std::string str;
 | 
			
		||||
  for (const auto& rank : ranks) {
 | 
			
		||||
    if (str.empty()) {
 | 
			
		||||
      str = std::to_string(rank);
 | 
			
		||||
    } else {
 | 
			
		||||
      str += ", " + std::to_string(rank);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return c10::str("[", str, "]");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct NCCLTraceBuffer {
 | 
			
		||||
  static NCCLTraceBuffer* get() {
 | 
			
		||||
    // intentionally leak on exit
 | 
			
		||||
@ -336,11 +382,12 @@ struct NCCLTraceBuffer {
 | 
			
		||||
                // buffer this entry will be located to
 | 
			
		||||
                // update state information
 | 
			
		||||
    size_t pg_id_;
 | 
			
		||||
    std::string pg_name_;
 | 
			
		||||
    size_t seq_id_; // as tracked by the process group
 | 
			
		||||
    const char* profiling_name_;
 | 
			
		||||
    std::string profiling_name_;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<torch::CapturedTraceback> traceback_;
 | 
			
		||||
    // we borrow pointser to start_ and end_ so we can query the state
 | 
			
		||||
    // we borrow pointers to start_ and end_ so we can query the state
 | 
			
		||||
    // on reporting. However, once the event is completed, the call
 | 
			
		||||
    // to `complete` will clear these.
 | 
			
		||||
    EventList *start_, *end_;
 | 
			
		||||
@ -348,8 +395,18 @@ struct NCCLTraceBuffer {
 | 
			
		||||
    // timestamp when the entry was created, likely close to the time the work
 | 
			
		||||
    // was 'enqueued'- not necessarily started
 | 
			
		||||
    c10::time_t time_created_;
 | 
			
		||||
    c10::optional<float> duration_;
 | 
			
		||||
 | 
			
		||||
    const char* state_ = "scheduled";
 | 
			
		||||
    // timestamp when our CPU threads discovered that the kernel started.
 | 
			
		||||
    // will always be _after_ it actually started, and can be very late
 | 
			
		||||
    // if the watchdog thread got stuck on CUDA APIs.
 | 
			
		||||
    c10::optional<c10::time_t> time_discovered_started_;
 | 
			
		||||
 | 
			
		||||
    // timestamp when our CPU threads discovered that the kernel completed.
 | 
			
		||||
    // will always be _after_ it actually completed, and can be the same time
 | 
			
		||||
    // as the discovery of the start if the watchdog thread is stuck on CUDA
 | 
			
		||||
    // APIs
 | 
			
		||||
    c10::optional<c10::time_t> time_discovered_completed_;
 | 
			
		||||
 | 
			
		||||
    // size information for input/output tensors
 | 
			
		||||
    c10::SmallVector<int, 4> input_dims_;
 | 
			
		||||
@ -369,8 +426,9 @@ struct NCCLTraceBuffer {
 | 
			
		||||
 | 
			
		||||
  c10::optional<size_t> record(
 | 
			
		||||
      size_t pg_id,
 | 
			
		||||
      const std::string& pg_name,
 | 
			
		||||
      size_t seq_id,
 | 
			
		||||
      const char* profiling_name,
 | 
			
		||||
      std::string profiling_name,
 | 
			
		||||
      const std::vector<at::Tensor>& inputs,
 | 
			
		||||
      const std::vector<at::Tensor>& outputs,
 | 
			
		||||
      EventList* start,
 | 
			
		||||
@ -385,8 +443,9 @@ struct NCCLTraceBuffer {
 | 
			
		||||
    auto te = Entry{
 | 
			
		||||
        id_,
 | 
			
		||||
        pg_id,
 | 
			
		||||
        pg_name,
 | 
			
		||||
        seq_id,
 | 
			
		||||
        profiling_name,
 | 
			
		||||
        std::move(profiling_name),
 | 
			
		||||
        std::move(traceback),
 | 
			
		||||
        std::move(start),
 | 
			
		||||
        std::move(end),
 | 
			
		||||
@ -424,8 +483,8 @@ struct NCCLTraceBuffer {
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      if (started) {
 | 
			
		||||
        r.state_ = "started";
 | 
			
		||||
      if (started && !r.time_discovered_started_) {
 | 
			
		||||
        r.time_discovered_started_ = c10::getTime();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (r.end_ != nullptr) {
 | 
			
		||||
@ -436,8 +495,8 @@ struct NCCLTraceBuffer {
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      if (completed) {
 | 
			
		||||
        r.state_ = "completed";
 | 
			
		||||
      if (completed && !r.time_discovered_completed_) {
 | 
			
		||||
        r.time_discovered_completed_ = c10::getTime();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -456,35 +515,97 @@ struct NCCLTraceBuffer {
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void retire_id(c10::optional<size_t> id) {
 | 
			
		||||
  /*
 | 
			
		||||
  Mark an Event as completed and free its events.
 | 
			
		||||
 | 
			
		||||
  This is called by the watchdog thread, and is asynchronous from the
 | 
			
		||||
  perspective of the main thread.
 | 
			
		||||
 | 
			
		||||
  compute_duration defaults to true since retire_id is only called in the
  watchdog thread, which is currently a place where we already call CUDA APIs
  that may hang; care should be taken to avoid computing duration in any
  function that must never hang. (Timing must also be enabled for
  compute_duration - see TORCH_NCCL_ENABLE_TIMING).
 | 
			
		||||
  */
 | 
			
		||||
  void retire_id(c10::optional<size_t> id, bool compute_duration = true) {
 | 
			
		||||
    if (!enabled_ || !id) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    std::lock_guard<std::mutex> guard(mutex_);
 | 
			
		||||
    auto& entry = entries_.at(*id % max_entries_);
 | 
			
		||||
    if (entry.id_ == *id) {
 | 
			
		||||
      update_state(entry);
 | 
			
		||||
      entry.retired_ = true;
 | 
			
		||||
      entry.start_ = entry.end_ = nullptr;
 | 
			
		||||
 | 
			
		||||
    bool can_compute_duration = false;
 | 
			
		||||
    EventList* startEvents = nullptr;
 | 
			
		||||
    EventList* endEvents = nullptr;
 | 
			
		||||
    c10::optional<float> duration = c10::nullopt;
 | 
			
		||||
 | 
			
		||||
    std::unique_lock<std::mutex> guard(mutex_);
 | 
			
		||||
 | 
			
		||||
    Entry* entry = &entries_.at(*id % max_entries_);
 | 
			
		||||
    if (entry->id_ == *id) {
 | 
			
		||||
      update_state(*entry);
 | 
			
		||||
 | 
			
		||||
      if (compute_duration) {
 | 
			
		||||
        can_compute_duration = entry->time_discovered_completed_.has_value() &&
 | 
			
		||||
            entry->start_ && entry->end_;
 | 
			
		||||
        startEvents = entry->start_;
 | 
			
		||||
        endEvents = entry->end_;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (can_compute_duration) {
 | 
			
		||||
      // Compute duration without holding the lock, because
 | 
			
		||||
      // cudaEventDuration() can hang, and we need to acquire the lock before we
 | 
			
		||||
      // can dump(), which we never want to block.
 | 
			
		||||
      guard.unlock();
 | 
			
		||||
      duration = getDurationFromFirstEvent(*startEvents, *endEvents);
 | 
			
		||||
      guard.lock();
 | 
			
		||||
 | 
			
		||||
      // Refresh the entry pointer, see if the entry has been overwritten
 | 
			
		||||
      entry = &entries_.at(*id % max_entries_);
 | 
			
		||||
      if (entry->id_ != *id) {
 | 
			
		||||
        LOG(INFO)
 | 
			
		||||
            << "retire_id abandoned for id " << *id
 | 
			
		||||
            << ", event was overwritten while waiting to compute duration.";
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      if (duration.has_value()) {
 | 
			
		||||
        entry->duration_ = duration.value();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    entry->retired_ = true;
 | 
			
		||||
    entry->start_ = entry->end_ = nullptr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string dump() {
 | 
			
		||||
  std::string dump(
 | 
			
		||||
      const c10::optional<std::unordered_map<
 | 
			
		||||
          std::string,
 | 
			
		||||
          std::unordered_map<std::string, std::string>>>& ncclDumpMap) {
 | 
			
		||||
    auto result = dump_entries();
 | 
			
		||||
    auto entries = new_list();
 | 
			
		||||
    c10::IValue pg_id_s = "pg_id";
 | 
			
		||||
    c10::IValue seq_id_s = "seq_id";
 | 
			
		||||
    c10::IValue profiling_name_s = "profiling_name";
 | 
			
		||||
    c10::IValue input_sizes_s = "input_sizes";
 | 
			
		||||
    c10::IValue output_sizes_s = "output_sizes";
 | 
			
		||||
    c10::IValue time_created_s = "time_created_us";
 | 
			
		||||
    c10::IValue entries_key = "entries";
 | 
			
		||||
    c10::IValue nccl_comm_key = "nccl_comm_state";
 | 
			
		||||
    c10::IValue version_key = "version";
 | 
			
		||||
    // Update whenever changing contents or formatting of the dump
 | 
			
		||||
    // (minor when adding fields, major when changing existing fields)
 | 
			
		||||
    c10::IValue version_val = "1.1";
 | 
			
		||||
 | 
			
		||||
    c10::IValue frames_s = "frames";
 | 
			
		||||
    c10::IValue state_s = "state";
 | 
			
		||||
    c10::IValue line_s = "line";
 | 
			
		||||
    c10::IValue name_s = "name";
 | 
			
		||||
    c10::IValue filename_s = "filename";
 | 
			
		||||
    c10::IValue retired_s = "retired";
 | 
			
		||||
    c10::IValue pg_id_key = "pg_id";
 | 
			
		||||
    c10::IValue pg_name_key = "process_group";
 | 
			
		||||
    c10::IValue seq_id_key = "seq_id";
 | 
			
		||||
    c10::IValue profiling_name_key = "profiling_name";
 | 
			
		||||
    c10::IValue input_sizes_key = "input_sizes";
 | 
			
		||||
    c10::IValue output_sizes_key = "output_sizes";
 | 
			
		||||
    c10::IValue time_created_key = "time_created_ns";
 | 
			
		||||
    c10::IValue duration_key = "duration_ms";
 | 
			
		||||
 | 
			
		||||
    c10::IValue frames_key = "frames";
 | 
			
		||||
    c10::IValue state_key = "state";
 | 
			
		||||
    c10::IValue line_key = "line";
 | 
			
		||||
    c10::IValue name_key = "name";
 | 
			
		||||
    c10::IValue filename_key = "filename";
 | 
			
		||||
    c10::IValue retired_key = "retired";
 | 
			
		||||
    c10::IValue time_discovered_started_key = "time_discovered_started_ns";
 | 
			
		||||
    c10::IValue time_discovered_completed_key = "time_discovered_completed_ns";
 | 
			
		||||
 | 
			
		||||
    std::vector<torch::CapturedTraceback*> tracebacks;
 | 
			
		||||
    for (auto& e : result) {
 | 
			
		||||
@ -494,9 +615,9 @@ struct NCCLTraceBuffer {
 | 
			
		||||
    std::vector<c10::IValue> all_frames;
 | 
			
		||||
    for (const auto& f : stracebacks.all_frames) {
 | 
			
		||||
      auto d = new_dict();
 | 
			
		||||
      d.insert(name_s, f.funcname);
 | 
			
		||||
      d.insert(filename_s, f.filename);
 | 
			
		||||
      d.insert(line_s, int64_t(f.lineno));
 | 
			
		||||
      d.insert(name_key, f.funcname);
 | 
			
		||||
      d.insert(filename_key, f.filename);
 | 
			
		||||
      d.insert(line_key, int64_t(f.lineno));
 | 
			
		||||
      all_frames.emplace_back(std::move(d));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -504,10 +625,14 @@ struct NCCLTraceBuffer {
 | 
			
		||||
      auto& e = result.at(i);
 | 
			
		||||
      auto& tb = stracebacks.tracebacks.at(i);
 | 
			
		||||
      auto dict = new_dict();
 | 
			
		||||
      dict.insert(pg_id_s, int64_t(e.pg_id_));
 | 
			
		||||
      dict.insert(seq_id_s, int64_t(e.seq_id_));
 | 
			
		||||
      dict.insert(profiling_name_s, e.profiling_name_);
 | 
			
		||||
      dict.insert(time_created_s, int64_t(e.time_created_ / 1000));
 | 
			
		||||
      dict.insert(pg_id_key, int64_t(e.pg_id_));
 | 
			
		||||
      dict.insert(pg_name_key, e.pg_name_);
 | 
			
		||||
      dict.insert(seq_id_key, int64_t(e.seq_id_));
 | 
			
		||||
      dict.insert(profiling_name_key, e.profiling_name_);
 | 
			
		||||
      dict.insert(time_created_key, int64_t(e.time_created_));
 | 
			
		||||
      if (e.duration_) {
 | 
			
		||||
        dict.insert(duration_key, *e.duration_);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      auto it = e.sizes_.begin();
 | 
			
		||||
      auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
 | 
			
		||||
@ -523,19 +648,55 @@ struct NCCLTraceBuffer {
 | 
			
		||||
        return sizes;
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      dict.insert(input_sizes_s, read_sizes(e.input_dims_));
 | 
			
		||||
      dict.insert(output_sizes_s, read_sizes(e.output_dims_));
 | 
			
		||||
      dict.insert(state_s, e.state_);
 | 
			
		||||
      dict.insert(retired_s, e.retired_);
 | 
			
		||||
      dict.insert(input_sizes_key, read_sizes(e.input_dims_));
 | 
			
		||||
      dict.insert(output_sizes_key, read_sizes(e.output_dims_));
 | 
			
		||||
      if (e.time_discovered_completed_.has_value()) {
 | 
			
		||||
        dict.insert(state_key, "completed");
 | 
			
		||||
      } else if (e.time_discovered_started_.has_value()) {
 | 
			
		||||
        dict.insert(state_key, "started");
 | 
			
		||||
      } else {
 | 
			
		||||
        dict.insert(state_key, "scheduled");
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      dict.insert(
 | 
			
		||||
          time_discovered_started_key,
 | 
			
		||||
          e.time_discovered_started_.has_value()
 | 
			
		||||
              ? int64_t(*e.time_discovered_started_)
 | 
			
		||||
              : c10::IValue());
 | 
			
		||||
      dict.insert(
 | 
			
		||||
          time_discovered_completed_key,
 | 
			
		||||
          e.time_discovered_completed_.has_value()
 | 
			
		||||
              ? int64_t(*e.time_discovered_completed_)
 | 
			
		||||
              : c10::IValue());
 | 
			
		||||
      dict.insert(retired_key, e.retired_);
 | 
			
		||||
 | 
			
		||||
      auto frames = new_list();
 | 
			
		||||
      for (int64_t frame : tb) {
 | 
			
		||||
        frames.push_back(all_frames.at(frame));
 | 
			
		||||
      }
 | 
			
		||||
      dict.insert(frames_s, frames);
 | 
			
		||||
      dict.insert(frames_key, frames);
 | 
			
		||||
      entries.push_back(dict);
 | 
			
		||||
    }
 | 
			
		||||
    return pickle_str(entries);
 | 
			
		||||
    // convert ncclDumpMap into a dictionary
 | 
			
		||||
    auto per_comm_dict = new_dict();
 | 
			
		||||
    if (ncclDumpMap.has_value()) {
 | 
			
		||||
      for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
 | 
			
		||||
        auto inner_dict = new_dict();
 | 
			
		||||
        for (const auto& [key, value] : ncclDump) {
 | 
			
		||||
          inner_dict.insert(key, value);
 | 
			
		||||
        }
 | 
			
		||||
        per_comm_dict.insert(ncclId, inner_dict);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto dict = new_dict();
 | 
			
		||||
    dict.insert(entries_key, entries);
 | 
			
		||||
    dict.insert(version_key, version_val);
 | 
			
		||||
    if (per_comm_dict.size() > 0) {
 | 
			
		||||
      dict.insert(nccl_comm_key, per_comm_dict);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return pickle_str(dict);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
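The dump() method above pickles a dictionary with an "entries" list, a
"version" string, and an optional "nccl_comm_state" map, and the registered
DebugInfoWriter writes that blob to TORCH_NCCL_DEBUG_INFO_TEMP_FILE plus the
rank (default prefix "/tmp/nccl_trace_rank_"). A minimal offline-inspection
sketch in Python, assuming rank 0 wrote a dump to the default path and that
the payload can be read back with the standard pickle module:

import pickle

with open("/tmp/nccl_trace_rank_0", "rb") as f:   # default prefix + rank
    trace = pickle.load(f)

print(trace["version"])              # "1.1" for the format in this change
for entry in trace["entries"]:
    # one record per collective, as built in dump() above
    print(entry["seq_id"], entry["profiling_name"], entry["state"],
          entry.get("duration_ms"))
print(trace.get("nccl_comm_state"))  # per-communicator NCCL state, if present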
 | 
			
		||||
@ -50,6 +50,35 @@
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
#ifdef USE_C10D_NCCL
 | 
			
		||||
 | 
			
		||||
bool acquire_gil() {
 | 
			
		||||
  // basically if this function can acquire the gil, it will return quickly.
 | 
			
		||||
  // if not, it will hang forever.  The idea is to call this from a thread
 | 
			
		||||
  // wrapped in a future, and then check the future after a timeout, to
 | 
			
		||||
  // determine whether we're facing gil contention.
 | 
			
		||||
  if (Py_IsInitialized()) {
 | 
			
		||||
    pybind11::gil_scoped_acquire gil;
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // If we end up here, it's probably still a "pass" from the perspective of
 | 
			
		||||
  // checking whether python is stuck. but currently we don't check the return
 | 
			
		||||
  // value of this function anyway, just check whether it returned quickly vs
 | 
			
		||||
  // timing out.  Taking a long time is the main sign of trouble.  Fast return
 | 
			
		||||
  // with true or with false is both OK from the perspective of debugging python
 | 
			
		||||
  // hangs.
 | 
			
		||||
  return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool registerGilChecker() {
 | 
			
		||||
  c10d::get_gil_checker() = &acquire_gil;
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static bool registered = registerGilChecker();
 | 
			
		||||
#endif // USE_C10D_NCCL
 | 
			
		||||
 | 
			
		||||
// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
 | 
			
		||||
// TODO: move this somewhere more generally useful
 | 
			
		||||
template <typename T>
 | 
			
		||||
@ -1033,6 +1062,29 @@ Example::
 | 
			
		||||
    >>> store.add("first_key", 6)
 | 
			
		||||
    >>> # Should return 7
 | 
			
		||||
    >>> store.get("first_key")
 | 
			
		||||
)")
 | 
			
		||||
          .def(
 | 
			
		||||
              "check",
 | 
			
		||||
              &::c10d::Store::check,
 | 
			
		||||
              py::call_guard<py::gil_scoped_release>(),
 | 
			
		||||
              R"(
 | 
			
		||||
The call to check whether a given list of ``keys`` have values stored in
the store. This call immediately returns in normal cases but still suffers
from some edge deadlock cases, e.g., calling check after TCPStore has been destroyed.
Calling :meth:`~torch.distributed.store.check` with a list of keys that
one wants to check whether they are stored in the store or not.

Arguments:
    keys (list[str]): The keys to query whether stored in the store.
 | 
			
		||||
 | 
			
		||||
Example::
 | 
			
		||||
    >>> import torch.distributed as dist
 | 
			
		||||
    >>> from datetime import timedelta
 | 
			
		||||
    >>> # Using TCPStore as an example, other store types can also be used
 | 
			
		||||
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
 | 
			
		||||
    >>> store.add("first_key", 1)
 | 
			
		||||
    >>> # Should return True
 | 
			
		||||
    >>> store.check(["first_key"])
 | 
			
		||||
)")
 | 
			
		||||
          .def(
 | 
			
		||||
              "delete_key",
 | 
			
		||||
@ -1404,7 +1456,11 @@ Arguments:
 | 
			
		||||
      .def_property_readonly(
 | 
			
		||||
          "underlying_store",
 | 
			
		||||
          &::c10d::PrefixStore::getUnderlyingStore,
 | 
			
		||||
          R"(Gets the underlying store object that PrefixStore wraps around.)");
 | 
			
		||||
          R"(Gets the underlying store object that PrefixStore wraps around.)")
 | 
			
		||||
      .def_property_readonly(
 | 
			
		||||
          "_underlying_non_prefix_store",
 | 
			
		||||
          &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
 | 
			
		||||
          R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
 | 
			
		||||
 | 
			
		||||
  auto processGroup =
 | 
			
		||||
      py::class_<
 | 
			
		||||
@ -1807,6 +1863,15 @@ Arguments:
 | 
			
		||||
              "group_name",
 | 
			
		||||
              &::c10d::ProcessGroup::getGroupName,
 | 
			
		||||
              "(Gets this process group name. It's cluster unique)")
 | 
			
		||||
          .def(
 | 
			
		||||
              "_set_group_desc",
 | 
			
		||||
              &::c10d::ProcessGroup::setGroupDesc,
 | 
			
		||||
              py::call_guard<py::gil_scoped_acquire>(),
 | 
			
		||||
              "Sets the process group description. This is an internal C10D method, do not use.")
 | 
			
		||||
          .def_property_readonly(
 | 
			
		||||
              "group_desc",
 | 
			
		||||
              &::c10d::ProcessGroup::getGroupDesc,
 | 
			
		||||
              "Gets this process group description")
 | 
			
		||||
          .def_property(
 | 
			
		||||
              "bound_device_id",
 | 
			
		||||
              &::c10d::ProcessGroup::getBoundDeviceId,
 | 
			
		||||
@ -2387,7 +2452,9 @@ Example::
 | 
			
		||||
      .def_readwrite(
 | 
			
		||||
          "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
 | 
			
		||||
      .def_readwrite(
 | 
			
		||||
          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color);
 | 
			
		||||
          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
 | 
			
		||||
      .def_readwrite(
 | 
			
		||||
          "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
448  torch/distributed/_state_dict_utils.py  Normal file

@ -0,0 +1,448 @@
 | 
			
		||||
import io
 | 
			
		||||
import math
 | 
			
		||||
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
 | 
			
		||||
 | 
			
		||||
if dist.is_available() or TYPE_CHECKING:
 | 
			
		||||
    from torch.distributed import distributed_c10d
 | 
			
		||||
    from torch.distributed._shard.sharded_tensor import ShardedTensor
 | 
			
		||||
    from torch.distributed._tensor import DTensor, Replicate
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _identity_func(
 | 
			
		||||
    obj: torch.Tensor,
 | 
			
		||||
    pg: Optional[dist.ProcessGroup],
 | 
			
		||||
    device: Optional[torch.device],
 | 
			
		||||
    companion_obj: Any,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    return obj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _all_gather_sharded_tensor(
 | 
			
		||||
    sharded_tensor: "ShardedTensor",
 | 
			
		||||
    pg: Optional[dist.ProcessGroup] = None,
 | 
			
		||||
    device: Optional[torch.device] = None,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    if pg is None:
 | 
			
		||||
        pg = distributed_c10d._get_default_group()
 | 
			
		||||
    world_size = dist.get_world_size(pg)
 | 
			
		||||
    shards = sharded_tensor.local_shards()
 | 
			
		||||
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
 | 
			
		||||
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
 | 
			
		||||
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
 | 
			
		||||
    pg_device = (
 | 
			
		||||
        distributed_c10d._get_pg_default_device(pg) if device is None else device
 | 
			
		||||
    )
 | 
			
		||||
    if shards:
 | 
			
		||||
        local_tensor = shards[0].tensor.flatten()
 | 
			
		||||
        if local_tensor.device.type != pg_device.type:
 | 
			
		||||
            local_tensor = local_tensor.to(pg_device)
 | 
			
		||||
        num_padding = chunk_size - local_tensor.numel()
 | 
			
		||||
        if num_padding > 0:
 | 
			
		||||
            local_tensor = F.pad(local_tensor, [0, num_padding])
 | 
			
		||||
    else:
 | 
			
		||||
        local_tensor = torch.zeros(
 | 
			
		||||
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    tensor = torch.empty(
 | 
			
		||||
        chunk_size * world_size,
 | 
			
		||||
        dtype=local_tensor.dtype,
 | 
			
		||||
        device=pg_device,
 | 
			
		||||
    )
 | 
			
		||||
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
 | 
			
		||||
 | 
			
		||||
    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
 | 
			
		||||
    return tensor
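# A worked example of the chunking arithmetic above (illustrative numbers,
# assuming a (5, 4) ShardedTensor sharded on dim 0 across world_size = 2):
#   dim_0_size = 5, tensor_numel = 20
#   chunk_size = math.ceil(5 / 2) * 20 // 5 = 3 * 4 = 12
#   rank 0 flattens rows 0-2 (12 elements, no padding);
#   rank 1 flattens rows 3-4 (8 elements, padded with 4 zeros to 12).
#   all_gather_into_tensor yields 24 elements; narrow(0, 0, 20) followed by
#   reshape(5, 4) recovers the full tensor.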
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompanionMismatch(Exception):
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _iterate_state_dict(
 | 
			
		||||
    iter_object: Any,
 | 
			
		||||
    sharded_tensor_func: Callable,
 | 
			
		||||
    dtensor_func: Callable,
 | 
			
		||||
    tensor_func: Callable,
 | 
			
		||||
    *,
 | 
			
		||||
    pg: Optional[dist.ProcessGroup] = None,
 | 
			
		||||
    device: Optional[torch.device] = None,
 | 
			
		||||
    cpu_offload: bool = False,
 | 
			
		||||
    companion_obj: Any = None,
 | 
			
		||||
    ranks_only: Tuple[int, ...] = tuple(),
 | 
			
		||||
    type_check: bool = True,
 | 
			
		||||
    non_blocking: bool = True,
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """Iterate through the state dict, applying the given functions to each tensor type.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        iter_object (Any): the target state_dict.
 | 
			
		||||
        sharded_tensor_func (Callable): the function to apply to ShardedTensor
 | 
			
		||||
        dtensor_func (Callable): the function to apply to DTensor
 | 
			
		||||
        tensor_func (Callable): the function to apply to Tensor
 | 
			
		||||
        pg (Optional[dist.ProcessGroup]): process group passed to tensor functions
 | 
			
		||||
        device (Optional[torch.device]): device passed to tensor functions
 | 
			
		||||
        cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored
 | 
			
		||||
            if a companion_obj is supplied.
 | 
			
		||||
        companion_obj (Any): A companion object to the state dict. If this object
 | 
			
		||||
            is supplied, we attempt to copy the tensor to the companion object.
 | 
			
		||||
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
 | 
			
		||||
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
 | 
			
		||||
            have the same state_dicts. Other ranks will get empty state_dicts.
 | 
			
		||||
        type_check (bool): check if the instance data type is a supported type
 | 
			
		||||
            that can be saved by DCP.  The current supported data types are
 | 
			
		||||
            torch.Tensor, DTensor, int, float, str, list, dict, None.
 | 
			
		||||
        non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
 | 
			
		||||
    """
 | 
			
		||||
    # TODO: should we use pytree?
 | 
			
		||||
    cpu_device = torch.device("cpu")
 | 
			
		||||
    if isinstance(iter_object, ShardedTensor):
 | 
			
		||||
        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
 | 
			
		||||
    elif isinstance(iter_object, DTensor):
 | 
			
		||||
        ret = dtensor_func(iter_object, pg, device, companion_obj)
 | 
			
		||||
    elif isinstance(iter_object, torch.Tensor):
 | 
			
		||||
        ret = tensor_func(iter_object, pg, device, companion_obj)
 | 
			
		||||
    elif (
 | 
			
		||||
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
 | 
			
		||||
        or iter_object is None
 | 
			
		||||
    ):
 | 
			
		||||
        ret = iter_object
 | 
			
		||||
    elif isinstance(iter_object, dict):
 | 
			
		||||
        if companion_obj is not None and (
 | 
			
		||||
            not isinstance(companion_obj, dict)
 | 
			
		||||
            or set(companion_obj.keys()) != set(iter_object.keys())
 | 
			
		||||
        ):
 | 
			
		||||
            raise CompanionMismatch()
 | 
			
		||||
 | 
			
		||||
        ret = {
 | 
			
		||||
            key: _iterate_state_dict(
 | 
			
		||||
                value,
 | 
			
		||||
                sharded_tensor_func,
 | 
			
		||||
                dtensor_func,
 | 
			
		||||
                tensor_func,
 | 
			
		||||
                pg=pg,
 | 
			
		||||
                device=device,
 | 
			
		||||
                cpu_offload=cpu_offload,
 | 
			
		||||
                companion_obj=companion_obj[key] if companion_obj is not None else None,
 | 
			
		||||
                ranks_only=ranks_only,
 | 
			
		||||
                type_check=type_check,
 | 
			
		||||
                non_blocking=non_blocking,
 | 
			
		||||
            )
 | 
			
		||||
            for key, value in iter_object.items()
 | 
			
		||||
        }
 | 
			
		||||
    elif isinstance(iter_object, (list, tuple)):
 | 
			
		||||
        if companion_obj is not None and (
 | 
			
		||||
            not isinstance(companion_obj, (list, tuple))
 | 
			
		||||
            or len(companion_obj) != len(iter_object)
 | 
			
		||||
        ):
 | 
			
		||||
            raise CompanionMismatch()
 | 
			
		||||
 | 
			
		||||
        ret = [
 | 
			
		||||
            _iterate_state_dict(
 | 
			
		||||
                v,
 | 
			
		||||
                sharded_tensor_func,
 | 
			
		||||
                dtensor_func,
 | 
			
		||||
                tensor_func,
 | 
			
		||||
                pg=pg,
 | 
			
		||||
                device=device,
 | 
			
		||||
                cpu_offload=cpu_offload,
 | 
			
		||||
                companion_obj=companion_obj[idx] if companion_obj is not None else None,
 | 
			
		||||
                ranks_only=ranks_only,
 | 
			
		||||
                type_check=type_check,
 | 
			
		||||
                non_blocking=non_blocking,
 | 
			
		||||
            )
 | 
			
		||||
            for idx, v in enumerate(iter_object)
 | 
			
		||||
        ]
 | 
			
		||||
        if isinstance(iter_object, tuple):
 | 
			
		||||
            ret = tuple(ret)
 | 
			
		||||
    elif not type_check:
 | 
			
		||||
        ret = iter_object
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Unexpected value type {type(iter_object)}")
 | 
			
		||||
 | 
			
		||||
    if not ranks_only or dist.get_rank(pg) in ranks_only:
 | 
			
		||||
        if isinstance(ret, torch.Tensor):
 | 
			
		||||
            if cpu_offload and companion_obj is None:
 | 
			
		||||
                ret = ret.to(cpu_device)
 | 
			
		||||
 | 
			
		||||
            if companion_obj is not None:
 | 
			
		||||
                # TODO: support DTensor
 | 
			
		||||
                companion_obj.copy_(ret, non_blocking=non_blocking)
 | 
			
		||||
                ret = companion_obj
 | 
			
		||||
    else:
 | 
			
		||||
        ret = {} if isinstance(ret, dict) else None
 | 
			
		||||
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _gather_state_dict(
 | 
			
		||||
    state_dict: Dict[str, Any],
 | 
			
		||||
    *,
 | 
			
		||||
    pg: Optional[dist.ProcessGroup] = None,
 | 
			
		||||
    device: Optional[torch.device] = None,
 | 
			
		||||
    cpu_offload: bool = False,
 | 
			
		||||
    ranks_only: Tuple[int, ...] = tuple(),
 | 
			
		||||
    type_check: bool = True,
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
 | 
			
		||||
    the state_dict.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        state_dict (Dict[str, Any]): the target sharded state_dict.
 | 
			
		||||
        pg (Optional[dist.ProcessGroup]): the process group that is used to
 | 
			
		||||
            gather ShardedTensor. Note that gathering a DTensor will use
 | 
			
		||||
            the DeviceMesh. So this argument will be ignored when gathering a
 | 
			
		||||
            DTensor.
 | 
			
		||||
        device: (Optional[torch.device]): the device that is used to
 | 
			
		||||
            perform allgather for ShardedTensor. Note that gathering a DTensor
 | 
			
		||||
            will use the DeviceMesh. So this argument will be ignored when
 | 
			
		||||
            gathering a DTensor.
 | 
			
		||||
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
 | 
			
		||||
            default value is False.
 | 
			
		||||
        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
 | 
			
		||||
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
 | 
			
		||||
            have the same state_dicts. Other ranks will get empty state_dicts.
 | 
			
		||||
        type_check: (bool): check if the instance data type is a supported type
 | 
			
		||||
            that can be saved by DCP.  The current supported data types are
 | 
			
		||||
            torch.Tensor, DTensor, int, float, str, list, dict, None.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        The gathered state dictionary.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def sharded_tensor_func(value, pg, device, companion_obj):
 | 
			
		||||
        # ShardedTensor does not seem to record the original device type.
 | 
			
		||||
        # So if the tensor is moved to CPU, we won't know the original type.
 | 
			
		||||
        # As a result, we have to rely on the user to tell us the correct one.
 | 
			
		||||
        cpu_device = torch.device("cpu")
 | 
			
		||||
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
 | 
			
		||||
        local_shard_device = (
 | 
			
		||||
            value.local_shards()[0].tensor.device
 | 
			
		||||
            if value.local_shards()
 | 
			
		||||
            else cpu_device
 | 
			
		||||
        )
 | 
			
		||||
        if output_tensor.device != local_shard_device:
 | 
			
		||||
            value = output_tensor.to(local_shard_device)
 | 
			
		||||
        else:
 | 
			
		||||
            value = output_tensor
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def dtensor_func(value, pg, device, companion_obj):
 | 
			
		||||
        if value.device != value.device_mesh.device_type:
 | 
			
		||||
            value = value.to(value.device_mesh.device_type)
 | 
			
		||||
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
 | 
			
		||||
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
 | 
			
		||||
        # 2D FSDP + TP all_gather:
 | 
			
		||||
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
 | 
			
		||||
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
 | 
			
		||||
        placements = [Replicate() for _ in value.placements]
 | 
			
		||||
        value = value.redistribute(
 | 
			
		||||
            device_mesh=value.device_mesh,
 | 
			
		||||
            placements=placements,
 | 
			
		||||
        )
 | 
			
		||||
        # Call `wait()` to force the tensor to be synchronous with respect
 | 
			
		||||
        # to the main stream.
 | 
			
		||||
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
 | 
			
		||||
        value = value.to_local()
 | 
			
		||||
        if isinstance(value, AsyncCollectiveTensor):
 | 
			
		||||
            value = value.wait()
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    return _iterate_state_dict(
 | 
			
		||||
        state_dict,
 | 
			
		||||
        sharded_tensor_func,
 | 
			
		||||
        dtensor_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        pg=pg,
 | 
			
		||||
        device=device,
 | 
			
		||||
        cpu_offload=cpu_offload,
 | 
			
		||||
        ranks_only=ranks_only,
 | 
			
		||||
        type_check=type_check,
 | 
			
		||||
    )
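# Usage sketch (illustrative; assumes torch.distributed is initialized and
# `model` holds ShardedTensor/DTensor parameters, e.g. under FSDP):
#
#   sharded_sd = model.state_dict()
#   full_sd = _gather_state_dict(sharded_sd, cpu_offload=True, ranks_only=(0,))
#   # rank 0 now holds fully-gathered CPU tensors; other ranks get empty dicts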
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _offload_state_dict_to_cpu(
 | 
			
		||||
    state_dict: Dict[str, Any],
 | 
			
		||||
    *,
 | 
			
		||||
    ranks_only: Tuple[int, ...] = tuple(),
 | 
			
		||||
    type_check: bool = True,
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Given a state_dict, this API offloads all the tensors to CPU memory.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        state_dict (Dict[str, Any]): the target state_dict.
 | 
			
		||||
 | 
			
		||||
        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
 | 
			
		||||
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
 | 
			
		||||
            have the same state_dicts. Other ranks will get empty state_dicts.
 | 
			
		||||
        type_check: (bool): check if the instance data type is a supported type
 | 
			
		||||
            that can be saved by DCP.  The current supported data types are
 | 
			
		||||
            torch.Tensor, DTensor, int, float, str, list, dict, None.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        The state dictionary with all tensors offloaded to the CPU.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    ret = _iterate_state_dict(
 | 
			
		||||
        state_dict,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        pg=None,
 | 
			
		||||
        device=None,
 | 
			
		||||
        cpu_offload=True,
 | 
			
		||||
        ranks_only=ranks_only,
 | 
			
		||||
        type_check=type_check,
 | 
			
		||||
    )
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _copy_state_dict(
 | 
			
		||||
    state_dict: Dict[str, Any],
 | 
			
		||||
    copy_state_dict: Dict[str, Any],
 | 
			
		||||
    non_blocking: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Copies all tensors in a given state dict into a different state_dict with the
 | 
			
		||||
    same structure.
 | 
			
		||||
 | 
			
		||||
    .. warning::
 | 
			
		||||
        It is expected by this function that state_dict and copy_state_dict share
 | 
			
		||||
        the same structure and data types.
 | 
			
		||||
 | 
			
		||||
    .. warning::
 | 
			
		||||
        The current supported data types are
 | 
			
		||||
            torch.Tensor, DTensor, int, float, str, list, dict, None.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        state_dict (Dict[str, Any]): the target state_dict.
 | 
			
		||||
        copy_state_dict (Dict[str, Any]):
 | 
			
		||||
            The state dict we are copying into. This state_dict must have exactly
 | 
			
		||||
             the same structure as the source `state_dict`.
 | 
			
		||||
        non_blocking: (bool): Whether copy ops should be performed asynchronously
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _iterate_state_dict(
 | 
			
		||||
        state_dict,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        pg=None,
 | 
			
		||||
        device=None,
 | 
			
		||||
        cpu_offload=False,
 | 
			
		||||
        ranks_only=tuple(),
 | 
			
		||||
        companion_obj=copy_state_dict,
 | 
			
		||||
        type_check=True,
 | 
			
		||||
        non_blocking=non_blocking,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_cpu_state_dict(
 | 
			
		||||
    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Given a state_dict, create another state_dict with the same structure and elements.
 | 
			
		||||
    However, all tensors in the returned state_dict are new tensors on CPU. These
 | 
			
		||||
    tensors can be placed on pin_memory or share_memory based on the provided arguments.
 | 
			
		||||
 | 
			
		||||
    .. warning::
 | 
			
		||||
        Setting both `pin_memory` and `share_memory` to True significantly increases the
 | 
			
		||||
        latency of this method because of the nuances which require us to register memory
 | 
			
		||||
        as pinned directly as opposed to relying on the pin_memory cache allocator. This
 | 
			
		||||
        option should only be used for long-lived tensors which are required to be shared.
        This is not the case as long as at least one of `pin_memory` or `share_memory` is
        set to False.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def tensor_func(
 | 
			
		||||
        obj: torch.Tensor,
 | 
			
		||||
        pg: Optional[dist.ProcessGroup],
 | 
			
		||||
        device: Optional[torch.device],
 | 
			
		||||
        _: Any,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        if len(obj.size()) == 0:
 | 
			
		||||
            return torch.tensor(0, dtype=obj.dtype)
 | 
			
		||||
 | 
			
		||||
        if share_memory:
 | 
			
		||||
            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_()
 | 
			
		||||
            if pin_memory:
 | 
			
		||||
                succ = torch.cuda.cudart().cudaHostRegister(
 | 
			
		||||
                    t.data_ptr(),
 | 
			
		||||
                    t.numel() * t.element_size(),
 | 
			
		||||
                    1,  # lines up with 'cudaHostRegisterPortable'
 | 
			
		||||
                )
 | 
			
		||||
                assert (
 | 
			
		||||
                    succ == 0
 | 
			
		||||
                ), f"Pinning shared memory failed with error-code: {succ}"
 | 
			
		||||
            return t
 | 
			
		||||
        elif pin_memory:
 | 
			
		||||
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
 | 
			
		||||
        else:
 | 
			
		||||
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)
 | 
			
		||||
 | 
			
		||||
    ret = _iterate_state_dict(
 | 
			
		||||
        state_dict,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        _identity_func,
 | 
			
		||||
        tensor_func,
 | 
			
		||||
        pg=None,
 | 
			
		||||
        device=None,
 | 
			
		||||
        cpu_offload=False,
 | 
			
		||||
        ranks_only=tuple(),
 | 
			
		||||
        type_check=False,
 | 
			
		||||
    )
 | 
			
		||||
    return ret
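# One way these helpers can be combined (a sketch, not part of the patch; the
# staging loop and `sd` -- a CUDA state_dict produced by the training code --
# are assumptions for illustration):
#
#   staged_sd = _create_cpu_state_dict(sd, pin_memory=True)  # allocate once
#   ...
#   _copy_state_dict(sd, staged_sd, non_blocking=True)       # reuse buffers
#   torch.cuda.synchronize()                                  # before reading staged_sd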
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_state_dict_similarity(
 | 
			
		||||
    state_dict: Dict[str, Any],
 | 
			
		||||
    compared_state_dict: Dict[str, Any],
 | 
			
		||||
) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Given two state_dicts, check if the structures are the same. And
    if a [key, tensor] pair exists in one state_dict there must be
    a corresponding pair, [key, other_tensor], in the other state_dict,
    where tensor and other_tensor have the same size and dtype.
 | 
			
		||||
 | 
			
		||||
    Return the check result.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def tensor_func(
 | 
			
		||||
        obj: torch.Tensor,
 | 
			
		||||
        pg: Optional[dist.ProcessGroup],
 | 
			
		||||
        device: Optional[torch.device],
 | 
			
		||||
        companion_obj: Any,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
 | 
			
		||||
            raise CompanionMismatch()
 | 
			
		||||
        return obj
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        _iterate_state_dict(
 | 
			
		||||
            state_dict,
 | 
			
		||||
            _identity_func,
 | 
			
		||||
            _identity_func,
 | 
			
		||||
            tensor_func,
 | 
			
		||||
            pg=None,
 | 
			
		||||
            device=None,
 | 
			
		||||
            cpu_offload=False,
 | 
			
		||||
            ranks_only=tuple(),
 | 
			
		||||
            companion_obj=compared_state_dict,
 | 
			
		||||
            type_check=False,
 | 
			
		||||
        )
 | 
			
		||||
    except CompanionMismatch:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    return True
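# Usage sketch (illustrative; `sd` and `staged_sd` are assumed to exist, e.g.
# built via _create_cpu_state_dict above):
#
#   if not _check_state_dict_similarity(sd, staged_sd):
#       raise RuntimeError("staging buffers no longer match the state_dict")
#   _copy_state_dict(sd, staged_sd, non_blocking=True)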
 | 
			
		||||
@ -1171,7 +1171,7 @@ def init_process_group(
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        default_pg, _ = _new_process_group_helper(
 | 
			
		||||
            -1, -1, [], backend, None, group_name, timeout=timeout
 | 
			
		||||
            -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg"
 | 
			
		||||
        )
 | 
			
		||||
        _update_default_pg(default_pg)
 | 
			
		||||
    else:
 | 
			
		||||
@ -1197,6 +1197,7 @@ def init_process_group(
 | 
			
		||||
            pg_options=pg_options,
 | 
			
		||||
            timeout=timeout,
 | 
			
		||||
            device_id=device_id,
 | 
			
		||||
            group_desc="default_pg"
 | 
			
		||||
        )
 | 
			
		||||
        _update_default_pg(default_pg)
 | 
			
		||||
 | 
			
		||||
@ -1257,6 +1258,7 @@ def _new_process_group_helper(
 | 
			
		||||
    timeout=None,
 | 
			
		||||
    pg_tag=None,
 | 
			
		||||
    device_id=None,
 | 
			
		||||
    group_desc=None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Create a new distributed process group.
 | 
			
		||||
@ -1289,6 +1291,8 @@ def _new_process_group_helper(
 | 
			
		||||
            _, prefix_store = _world.pg_map[existing_group]
 | 
			
		||||
            return existing_group, prefix_store
 | 
			
		||||
 | 
			
		||||
    group_desc = "undefined" if group_desc is None else group_desc
 | 
			
		||||
 | 
			
		||||
    # The list of group ranks is empty if we're creating the default group.
 | 
			
		||||
    is_default_group = len(global_ranks_in_group) == 0
 | 
			
		||||
 | 
			
		||||
@ -1375,6 +1379,7 @@ def _new_process_group_helper(
 | 
			
		||||
            if split_from:
 | 
			
		||||
                pg_options.split_from = split_from
 | 
			
		||||
                pg_options.split_color = _process_group_color(global_ranks_in_group)
 | 
			
		||||
            pg_options.group_name = group_name
 | 
			
		||||
            backend_class = ProcessGroupNCCL(
 | 
			
		||||
                backend_prefix_store, group_rank, group_size, pg_options)
 | 
			
		||||
            backend_type = ProcessGroup.BackendType.NCCL
 | 
			
		||||
@ -1461,9 +1466,11 @@ def _new_process_group_helper(
 | 
			
		||||
 | 
			
		||||
    # update global state
 | 
			
		||||
    assert group_name is not None
 | 
			
		||||
    assert group_desc is not None
 | 
			
		||||
    _world.pg_map[pg] = (backend, prefix_store)
 | 
			
		||||
    _world.pg_names[pg] = group_name
 | 
			
		||||
    pg._set_group_name(group_name)
 | 
			
		||||
    pg._set_group_desc(group_desc)
 | 
			
		||||
 | 
			
		||||
    _world.pg_backend_config[pg] = str(backend_config)
 | 
			
		||||
    # "" is the default tag for user PGs
 | 
			
		||||
@ -3614,7 +3621,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@_time_logger
 | 
			
		||||
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False):
 | 
			
		||||
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None):
 | 
			
		||||
    """
 | 
			
		||||
    Create a new distributed group.
 | 
			
		||||
 | 
			
		||||
@ -3655,6 +3662,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
 | 
			
		||||
            barrier at the end of the process group creation. This is different
 | 
			
		||||
            in that non-member ranks don't need to call into API and don't
 | 
			
		||||
            join the barrier.
 | 
			
		||||
        group_desc (str, optional): a string to describe the process group.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``.
 | 
			
		||||
@ -3669,7 +3677,15 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
 | 
			
		||||
    multiple overlapping process groups. To avoid that, make sure all ranks follow the
 | 
			
		||||
    same global creation order.
 | 
			
		||||
    """
 | 
			
		||||
    return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
 | 
			
		||||
    return _new_group_with_tag(
 | 
			
		||||
        ranks,
 | 
			
		||||
        timeout,
 | 
			
		||||
        backend,
 | 
			
		||||
        pg_options,
 | 
			
		||||
        None,
 | 
			
		||||
        use_local_synchronization=use_local_synchronization,
 | 
			
		||||
        group_desc=group_desc,
 | 
			
		||||
    )
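# Usage sketch for the new group_desc argument (ranks and the description
# string are illustrative; assumes `import torch.distributed as dist` and an
# initialized default group):
#
#   tp_group = dist.new_group(ranks=[0, 1], group_desc="tensor_parallel")
#   assert tp_group.group_desc == "tensor_parallel"  # exposed via the pybind property above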
 | 
			
		||||
 | 
			
		||||
def _new_group_with_tag(
 | 
			
		||||
    ranks=None,
 | 
			
		||||
@ -3677,7 +3693,8 @@ def _new_group_with_tag(
 | 
			
		||||
    backend=None,
 | 
			
		||||
    pg_options=None,
 | 
			
		||||
    pg_tag=None,
 | 
			
		||||
    use_local_synchronization=False
 | 
			
		||||
    use_local_synchronization=False,
 | 
			
		||||
    group_desc=None
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Variant of ``new_group`` that exposes tag creation.
 | 
			
		||||
@ -3749,7 +3766,8 @@ def _new_group_with_tag(
 | 
			
		||||
        group_name,
 | 
			
		||||
        pg_options=pg_options,
 | 
			
		||||
        timeout=timeout,
 | 
			
		||||
        pg_tag=pg_tag
 | 
			
		||||
        pg_tag=pg_tag,
 | 
			
		||||
        group_desc=group_desc
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Create the global rank to group rank mapping
 | 
			
		||||
@ -3789,6 +3807,7 @@ def new_subgroups(
 | 
			
		||||
    timeout=None,
 | 
			
		||||
    backend=None,
 | 
			
		||||
    pg_options=None,
 | 
			
		||||
    group_desc=None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Create subgroups of equal size.
 | 
			
		||||
@ -3841,6 +3860,8 @@ def new_subgroups(
 | 
			
		||||
            the construction of specific process groups. i.e. for the ``nccl``
 | 
			
		||||
            backend, ``is_high_priority_stream`` can be specified so that
 | 
			
		||||
            process group can pick up high priority cuda streams.
 | 
			
		||||
        group_desc (str, optional): A string describing the group. Each subgroup will
 | 
			
		||||
            inherit its group_desc.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        The subgroup containing the current rank, and all the subgroups used for cleanup.
 | 
			
		||||
@ -3886,6 +3907,7 @@ def new_subgroups(
 | 
			
		||||
            timeout=timeout,
 | 
			
		||||
            backend=backend,
 | 
			
		||||
            pg_options=pg_options,
 | 
			
		||||
            group_desc=group_desc,
 | 
			
		||||
        )
 | 
			
		||||
        subgroups.append(subgroup)
 | 
			
		||||
 | 
			
		||||
@ -3905,6 +3927,7 @@ def new_subgroups_by_enumeration(
 | 
			
		||||
    timeout=None,
 | 
			
		||||
    backend=None,
 | 
			
		||||
    pg_options=None,
 | 
			
		||||
    group_desc=None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Create subgroups by dividing the global world.
 | 
			
		||||
@ -3945,6 +3968,8 @@ def new_subgroups_by_enumeration(
 | 
			
		||||
            the construction of specific process groups. i.e. for the ``nccl``
 | 
			
		||||
            backend, ``is_high_priority_stream`` can be specified so that
 | 
			
		||||
            process group can pick up high priority cuda streams.
 | 
			
		||||
        group_desc (str, optional): A string describing the group. Each subgroup will
 | 
			
		||||
            inherit its group_desc.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        The subgroup containing the current rank, and all the subgroups used for cleanup.
 | 
			
		||||
@ -3973,6 +3998,7 @@ def new_subgroups_by_enumeration(
 | 
			
		||||
            timeout=timeout,
 | 
			
		||||
            backend=backend,
 | 
			
		||||
            pg_options=pg_options,
 | 
			
		||||
            group_desc=group_desc,
 | 
			
		||||
        )
 | 
			
		||||
        subgroups.append(subgroup)
 | 
			
		||||
        my_rank = get_rank()
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ def tail_logfile(
 | 
			
		||||
            return
 | 
			
		||||
        time.sleep(interval_sec)
 | 
			
		||||
 | 
			
		||||
    with open(file) as fp:
 | 
			
		||||
    with open(file, errors="replace") as fp:
 | 
			
		||||
        while True:
 | 
			
		||||
            line = fp.readline()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||