diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index ddcc5f5cc148..e152feba9ccc 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -4205,8 +4205,11 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};
 MemPool::MemPool(
     CUDACachingAllocator::CUDAAllocator* allocator,
     bool is_user_created,
-    bool use_on_oom)
-    : allocator_(allocator), is_user_created_(is_user_created) {
+    bool use_on_oom,
+    bool symmetric)
+    : allocator_(allocator),
+      is_user_created_(is_user_created),
+      symmetric_(symmetric) {
   if (is_user_created_) {
     id_ = {0, uid_++};
   } else {
@@ -4229,6 +4232,10 @@ MempoolId_t MemPool::id() {
   return id_;
 }
 
+bool MemPool::is_symmetric() {
+  return symmetric_;
+}
+
 CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
   return allocator_;
 }
diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
index 23d99b0c2f38..a6fa61110d67 100644
--- a/c10/cuda/CUDACachingAllocator.h
+++ b/c10/cuda/CUDACachingAllocator.h
@@ -535,7 +535,8 @@ struct C10_CUDA_API MemPool {
   MemPool(
       CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
       bool is_user_created = true,
-      bool use_on_oom = false);
+      bool use_on_oom = false,
+      bool symmetric = false);
   MemPool(const MemPool&) = delete;
   MemPool(MemPool&&) = default;
   MemPool& operator=(const MemPool&) = delete;
@@ -543,6 +544,7 @@ struct C10_CUDA_API MemPool {
   ~MemPool();
 
   MempoolId_t id();
+  bool is_symmetric();
   CUDACachingAllocator::CUDAAllocator* allocator();
   int use_count();
   c10::DeviceIndex device();
@@ -554,6 +556,7 @@ struct C10_CUDA_API MemPool {
   CUDACachingAllocator::CUDAAllocator* allocator_;
   bool is_user_created_;
   MempoolId_t id_;
+  bool symmetric_;
   c10::DeviceIndex device_;
 };
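The new `symmetric` flag threads from the Python `MemPool` wrapper (see the `torch/cuda/memory.py` hunk at the end of this patch) through the pybind layer down to the c10 pool above. "Symmetric" means every rank is expected to allocate from its pool in the same order and with the same sizes, so that NCCL can later register the resulting segments as matching collective windows. A minimal sketch of that SPMD allocation discipline (hypothetical sizes, not part of this patch):

    # Every rank must perform the same sequence of allocations from a
    # symmetric pool, otherwise window registration cannot line up.
    pool = torch.cuda.MemPool(symmetric=True)
    with torch.cuda.use_mem_pool(pool):
        # Same dtype, size, and order on every rank.
        buf = torch.empty(1 << 20, device="cuda", dtype=torch.float32)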
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 032700e1635a..4d2a9565880b 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3084,7 +3084,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
 
-class NcclUserBufferRegistrationTest(MultiProcessTestCase):
+class NcclRegistrationTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3095,7 +3095,7 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
         os.environ["NCCL_DEBUG"] = "INFO"
         os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
         if torch.cuda.nccl.version() >= (2, 24, 3):
-            os.environ["NCCL_DEBUG_SUBSYS"] = "REG"
+            os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
         os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
         self._spawn_processes()
@@ -3151,6 +3151,48 @@
         else:
             self.assertRegex(nccl_debug_file_content, "local-registered")
 
+    @requires_nccl()
+    @requires_nccl_version((2, 27), "Need NCCL 2.27 for window registration")
+    @skip_if_lt_x_gpu(4)
+    @requires_multicast_support()
+    def test_nccl_window_registration(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+        device = torch.device(f"cuda:{self.rank}")
+        torch.cuda.set_device(self.rank)
+        pg = c10d.distributed_c10d._get_default_group()
+        backend = pg._get_backend(torch.device(device))
+
+        # Use the NCCL memory allocator and enable symmetric memory usage
+        # in NCCL.
+        pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
+
+        # Allocate memory with ncclMemAlloc. Note: symmetric kernels are
+        # not available for dtypes like torch.int64.
+        with torch.cuda.use_mem_pool(pool):
+            tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32)
+
+        # Register buffers to NCCL.
+        backend.register_mem_pool(pool)
+
+        # The allreduce should now use NVIDIA Switches (NVLS).
+        pg.allreduce(tensor).wait()
+        torch.cuda.synchronize(device=device)
+
+        # De-register buffers from NCCL.
+        backend.deregister_mem_pool(pool)
+
+        # Clean up memory.
+        del tensor, pool
+
+        with open(os.environ["NCCL_DEBUG_FILE"]) as f:
+            nccl_debug_file_content = f.read()
+            # If buffers were registered and symmetric kernels ran, the NCCL
+            # debug log should show successful window registration.
+            self.assertRegex(nccl_debug_file_content, r"\[Symmetric\]")
+
 
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
     @property
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index c267b0b9451a..260de8e33675 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2300,10 +2300,13 @@ class _MemPool:
         allocator: _cuda_CUDAAllocator | None = None,
         is_user_created: _bool = True,
         use_on_oom: _bool = False,
+        symmetric: _bool = False,
     ) -> None: ...
     @property
     def id(self) -> tuple[_int, _int]: ...
     @property
+    def is_symmetric(self) -> _bool: ...
+    @property
     def allocator(self) -> _cuda_CUDAAllocator | None: ...
     def use_count(self) -> _int: ...
diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp
index b651a4b5e68a..feb22e360bb9 100644
--- a/torch/csrc/cuda/MemPool.cpp
+++ b/torch/csrc/cuda/MemPool.cpp
@@ -16,12 +16,15 @@ void THCPMemPool_init(PyObject* module) {
       .def(
           py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
                       bool is_user_created,
-                      bool use_on_oom) {
+                      bool use_on_oom,
+                      bool symmetric) {
             torch::utils::device_lazy_init(at::kCUDA);
             return std::make_shared<::c10::cuda::MemPool>(
-                allocator, is_user_created, use_on_oom);
+                allocator, is_user_created, use_on_oom, symmetric);
           }))
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
+      .def_property_readonly(
+          "is_symmetric", &::c10::cuda::MemPool::is_symmetric)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
       .def("use_count", &::c10::cuda::MemPool::use_count);
 }
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp
index 7703fab06084..a9178b3bbca8 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.cpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -350,7 +350,8 @@ ncclResult_t NCCLComm::checkForNcclError() {
 ncclResult_t NCCLComm::registerSegment(
     void* ptr,
     size_t size,
-    bool errorOnRereg /*=true*/) {
+    bool errorOnRereg, /*=true*/
+    bool window /*=false*/) {
   LockType lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
   // We register only segments from cache allocator
@@ -371,6 +372,30 @@ ncclResult_t NCCLComm::registerSegment(
   void* handle = nullptr;
   // Use getNcclComm to make sure comm is ready before calling nccl APIs
   auto comm = getNcclComm();
+#ifdef NCCL_HAS_COMM_WINDOW_REGISTER
+  if (window) {
+    C10D_NCCL_CHECK(
+        ncclCommWindowRegister(
+            comm, ptr, size, (ncclWindow_t*)&handle, NCCL_WIN_COLL_SYMMETRIC),
+        c10::str(
+            "Failed to window register segment with ptr ",
+            ptr,
+            ", size ",
+            size,
+            " on ncclComm_ ",
+            comm));
+  } else {
+    C10D_NCCL_CHECK(
+        ncclCommRegister(comm, ptr, size, &handle),
+        c10::str(
+            "Failed to register segment with ptr ",
+            ptr,
+            ", size ",
+            size,
+            " on ncclComm_ ",
+            comm));
+  }
+#else
   C10D_NCCL_CHECK(
       ncclCommRegister(comm, ptr, size, &handle),
       c10::str(
@@ -380,6 +405,7 @@ ncclResult_t NCCLComm::registerSegment(
           size,
           " on ncclComm_ ",
           comm));
+#endif
   registeredSegmentHandles_[ptr] = handle;
   return ncclSuccess;
 #else
@@ -387,7 +413,7 @@ ncclResult_t NCCLComm::registerSegment(
 #endif
 }
 
-ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
+ncclResult_t NCCLComm::deregisterSegment(void* ptr, bool window /*=false*/) {
   LockType lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
   TORCH_CHECK(
@@ -400,6 +426,29 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
   void* handle = registeredSegmentHandles_[ptr];
   // Use getNcclComm to make sure comm is ready before calling nccl APIs
   auto comm = getNcclComm();
+#ifdef NCCL_HAS_COMM_WINDOW_REGISTER
+  if (window) {
+    C10D_NCCL_CHECK(
+        ncclCommWindowDeregister(comm, (ncclWindow_t)handle),
+        c10::str(
+            "Failed to window deregister segment handle ",
+            handle,
+            ", with ptr ",
+            ptr,
+            " on ncclComm_ ",
+            comm));
+  } else {
+    C10D_NCCL_CHECK(
+        ncclCommDeregister(comm, handle),
+        c10::str(
+            "Failed to deregister segment handle ",
+            handle,
+            ", with ptr ",
+            ptr,
+            " on ncclComm_ ",
+            comm));
+  }
+#else
   C10D_NCCL_CHECK(
       ncclCommDeregister(comm, handle),
       c10::str(
@@ -409,6 +458,7 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
           ptr,
           " on ncclComm_ ",
           comm));
+#endif
   registeredSegmentHandles_.erase(ptr);
   return ncclSuccess;
 #else
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 2780a7196349..4061b3f5788c 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -63,6 +63,10 @@ static_assert(
 #define NCCL_HAS_COMM_REGISTER
 #endif
 
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
+#define NCCL_HAS_COMM_WINDOW_REGISTER
+#endif
+
 #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
 #define NCCL_HAS_MEM_ALLOC
 #endif
@@ -341,9 +345,10 @@ class NCCLComm {
   ncclResult_t registerSegment(
       void* ptr,
       size_t size,
-      bool errorOnRereg = true);
+      bool errorOnRereg = true,
+      bool window = false);
 
-  ncclResult_t deregisterSegment(void* ptr);
+  ncclResult_t deregisterSegment(void* ptr, bool window = false);
 
   std::string repr() const;
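The `NCCL_HAS_COMM_WINDOW_REGISTER` macro gates the `ncclCommWindowRegister`/`ncclCommWindowDeregister` paths at compile time, mirroring the `requires_nccl_version((2, 27), ...)` guard in the test above. The equivalent runtime check from Python is a one-liner (a sketch; it assumes `torch.cuda.nccl.version()` reports the NCCL build PyTorch links against):

    # Window (symmetric) registration needs NCCL >= 2.27; otherwise only
    # plain buffer registration (ncclCommRegister) is available.
    has_window_registration = torch.cuda.nccl.version() >= (2, 27)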
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index c0cfe6f0caa4..8e881d3f2617 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1182,6 +1182,14 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
   // register future segments allocated in this pool (this call is idempotent).
   attachAllocatorHooks();
   auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
+  // TODO:
+  // if (pool->is_symmetric()) {
+  //   Allgather to verify that len(mempool.snapshot.segments) matches across GPUs.
+  //   Allgather to verify that mempool.alloc_request_counter matches across GPUs.
+  //   Add an alloc_request_counter per mempool (how many allocations a mempool
+  //   has served during its lifetime); this should guarantee that the pool is
+  //   used in a symmetric/SPMD manner.
+  // }
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -1190,7 +1198,9 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
         // NOLINTNEXTLINE(performance-no-int-to-ptr)
         reinterpret_cast<void*>(segmentInfo.address),
         segmentInfo.total_size,
-        /*errorOnRereg=*/false); // ignores reregistration error
+        /*errorOnRereg=*/false, // ignores reregistration error
+        /*window=*/pool->is_symmetric()); // whether to use NCCL symmetric
+                                          // memory
   }
 }
@@ -1221,7 +1231,8 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
         segmentInfo.device == pool->device(),
         "Mismatch between CUDA memory segment device and pool's device");
     // NOLINTNEXTLINE(performance-no-int-to-ptr)
-    ncclComm->deregisterSegment(reinterpret_cast<void*>(segmentInfo.address));
+    ncclComm->deregisterSegment(
+        reinterpret_cast<void*>(segmentInfo.address), pool->is_symmetric());
   }
 }
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 5e1914e701c5..ef474df7cfb0 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -1163,21 +1163,28 @@ class MemPool(_MemPool):
         use_on_oom(bool): a bool that indicates if this pool can be used
             as a last resort if a memory allocation outside of the pool
             fails due to Out Of Memory. This is False by default.
-
+        symmetric(bool): a bool that indicates if this pool is symmetric
+            across ranks. This is False by default.
     """
 
     def __init__(
         self,
        allocator: Optional[_cuda_CUDAAllocator] = None,
        use_on_oom: bool = False,
+        symmetric: bool = False,
     ):
-        super().__init__(allocator, True, use_on_oom)
+        super().__init__(allocator, True, use_on_oom, symmetric)
 
     @property
     def id(self) -> tuple[int, int]:
         r"""Returns the ID of this pool as a tuple of two ints."""
         return super().id
 
+    @property
+    def is_symmetric(self) -> bool:
+        r"""Returns whether this pool is used for NCCL's symmetric memory."""
+        return super().is_symmetric
+
     @property
     def allocator(self) -> Optional[_cuda_CUDAAllocator]:
         r"""Returns the allocator this MemPool routes allocations to."""
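Taken together, these pieces enable the following user-level flow, distilled from `test_nccl_window_registration` above. This is a sketch, not part of the patch; it assumes an already-initialized NCCL process group, one rank per GPU, and NCCL >= 2.27 with multicast support:

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("nccl", ...) has already run.
    pg = dist.distributed_c10d._get_default_group()
    backend = pg._get_backend(torch.device("cuda"))

    # Route allocations through NCCL's allocator and mark the pool symmetric.
    pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
    assert pool.is_symmetric

    with torch.cuda.use_mem_pool(pool):
        t = torch.ones(1 << 21, device="cuda", dtype=torch.float32)

    # Because the pool is symmetric, this registers its segments through
    # ncclCommWindowRegister; collectives on t may then use symmetric kernels.
    backend.register_mem_pool(pool)
    pg.allreduce(t).wait()
    backend.deregister_mem_pool(pool)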