diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 2ab444a4b689..522b6815ada3 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -2007,7 +2007,7 @@ class DistributedDataParallelTest(
         replica_devices = [dev0]
         # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
         layer_devs = dev0
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
 
     @requires_nccl()
@@ -2021,7 +2021,7 @@ class DistributedDataParallelTest(
         replica_devices = None
         # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device.
         layer_devs = [dev0] * 2 + [dev1] * 2
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)
 
     @requires_nccl()
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index 9d188c9c26d6..ff83d687f8a1 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -417,6 +417,21 @@ class TORCH_API Backend : public torch::CustomClassHolder {
         "Backend ", getBackendName(), " does not support getMemAllocator"));
   }
 
+  // Allocate tensor (aten::empty) from backend's communication-optimized
+  // memory pool
+  virtual at::Tensor allocateTensor(long size, at::TensorOptions options = {}) {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ", getBackendName(), " does not support allocateTensor"));
+  }
+
+  // Returns true if the backend supports tensor allocation
+  virtual bool supportsTensorAlloc() {
+    // Change to true in concrete backend if supported
+    return false;
+  }
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp
index 9b5c59624795..99fc244af023 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.cpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -340,19 +340,26 @@ ncclResult_t NCCLComm::checkForNcclError() {
 #endif
 }
 
-ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
+ncclResult_t NCCLComm::registerSegment(
+    void* ptr,
+    size_t size,
+    bool errorOnRereg /*=true*/) {
   LockType lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
   // We register only segments from cache allocator
   // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
   // maps to a unique handle and should not be registered before the current
   // ptr is deregistered and freed.
-  TORCH_CHECK(
-      registeredSegmentHandles_.count(ptr) == 0,
-      "Segment with ptr ",
-      ptr,
-      " has already been registered on ncclComm_ ",
-      ncclComm_);
+  if (registeredSegmentHandles_.count(ptr) > 0) {
+    TORCH_CHECK(
+        !errorOnRereg,
+        "Segment with ptr ",
+        ptr,
+        " has already been registered on ncclComm_ ",
+        ncclComm_);
+    // Skip below
+    return ncclSuccess;
+  }
 
   void* handle = nullptr;
   // Use getNcclComm to make sure comm is ready before calling nccl APIs
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 1ec814948562..c7cd0a30924e 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -284,7 +284,10 @@ class NCCLComm {
 
   ncclResult_t checkForNcclError();
 
-  ncclResult_t registerSegment(void* ptr, size_t size);
+  ncclResult_t registerSegment(
+      void* ptr,
+      size_t size,
+      bool errorOnRereg = true);
 
   ncclResult_t deregisterSegment(void* ptr);
 
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index d69fb2f5c36a..cd9363ec337f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1175,7 +1175,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
     ncclComm->registerSegment(
         // NOLINTNEXTLINE(performance-no-int-to-ptr)
         reinterpret_cast<void*>(segmentInfo.address),
-        segmentInfo.total_size);
+        segmentInfo.total_size,
+        /*errorOnRereg=*/false); // ignores reregistration error
   }
 }
 
@@ -1455,6 +1456,14 @@ void ProcessGroupNCCL::shutdown() {
     // Use long interval to avoid acquiring CPU too frequently
     ncclComm->waitReady(true);
   }
+  // Deregister memory pool after finalizing all collectives
+  if (memPool_) {
+    try {
+      deregisterMemPool(memPool_.get());
+    } catch (...) {
+      LOG(ERROR) << logPrefix() << "Failed to deregister memory pool, ignoring";
+    }
+  }
   // Tell watchdog to (1) flush its queue and (2) do not use comm objects
   // anymore because I am going to destroy them now
   LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread.";
@@ -5422,6 +5431,46 @@ std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() {
   return ncclMemAllocator;
 }
 
+at::Tensor ProcessGroupNCCL::allocateTensor(
+    long size,
+    at::TensorOptions options) {
+  // Some checks
+  TORCH_CHECK_VALUE(options.has_device(), "Tensor options must include device");
+  auto device = options.device();
+  TORCH_CHECK_VALUE(
+      device.is_cuda(),
+      "NCCL tensor allocator expects cuda type but got " + c10::str(device))
+
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
+  // Create memory pool
+  if (!memPool_) {
+    // Needs a CUDAAllocator
+    auto allocator =
+        reinterpret_cast<c10::cuda::CUDACachingAllocator::CUDAAllocator*>(
+            getMemAllocator().get());
+    // Pool is created
+    memPool_ = std::make_unique<c10::cuda::MemPool>(allocator);
+    LOG(INFO) << logPrefix() << "Created memory pool";
+  }
+
+  // Allocate tensor under this MemPool's context
+  auto ctx = c10::cuda::MemPoolContext(memPool_.get());
+  c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+      memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; });
+  at::Tensor tensor = at::empty({size}, options);
+  // Also need to ncclCommRegister the pool in case new segments are created;
+  // reregistration of old segments will be ignored
+  registerMemPool(memPool_.get());
+  c10::cuda::CUDACachingAllocator::endAllocateToPool(
+      memPool_->device(), memPool_->id());
+  c10::cuda::CUDACachingAllocator::releasePool(
+      memPool_->device(), memPool_->id());
+  LOG(INFO) << logPrefix() << "Allocated tensor of size " << size
+            << " from memory pool";
+  return tensor;
+}
+
 } // namespace c10d
 
 #endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 002b3a1a1433..185d9bebe6eb 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -774,6 +774,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   std::shared_ptr<c10::Allocator> getMemAllocator() override;
 
+  // Allocate tensor from communication-optimized memory pool
+  at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override;
+
+  bool supportsTensorAlloc() override {
+    return true;
+  }
+
   // Performs NCCL user buffer registration for all buffers in
   // the given MemPool
   void registerMemPool(c10::cuda::MemPool* pool);
@@ -1294,6 +1301,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Internal cached value: use NCCL non-blocking API mode or not.
   // Use `useNonblocking()` method instead of accessing this variable directly.
   std::optional<bool> useNonblocking_{std::nullopt};
+
+  // Communication-optimized memory pool associated with this PG
+  std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
 };
 
 // Dumps the NCCL comm traces and additional information about the Process
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 03c1380bfe79..800269fe14ef 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -1157,14 +1157,44 @@ void Reducer::initialize_buckets(
       offset += length;
     }
 
-    // Allocate the bucket's flattened `gradients` tensor.
     // Make gradient type in the reduced precision if mixed precision is
     // enabled. This ensures that the type is correct when e.g. rebuilding
     // buckets.
     if (mixed_precision_param_dtype_.has_value()) {
      options = options.dtype(mixed_precision_param_dtype_);
     }
-    bucket.gradients = at::empty({static_cast<int64_t>(offset)}, options);
+
+    // Allocate the bucket's flattened `gradients` tensor.
+    auto bucketSize = static_cast<long>(offset);
+    // Check if we can use the comm-optimized memory pool to allocate the tensor.
+    c10::intrusive_ptr<Backend> backend = nullptr;
+    // An environment variable to disable the comm-optimized memory pool.
+    // Default is 0, which means the comm-optimized memory pool is enabled.
+    // Users can set it to 1 if they see regressions or OOMs (because this
+    // comm MemPool may not share space with the regular compute MemPool).
+    bool ddpDisableCommMem =
+        (getCvarString({"DDP_DISABLE_COMM_MEM"}, "0") == "1");
+    try {
+      backend = process_group_->getDefaultBackend();
+    } catch (...) {
+      // Sometimes the backend type can be `UNDEFINED` rather than `NCCL` or
+      // `GLOO`. In this case, we just fall back to the regular way of
+      // creating the tensor.
+      LOG(INFO)
+          << "Reducer: default comm backend not found, skipping bucket memory optimization";
+    }
+    if (!ddpDisableCommMem && backend != nullptr &&
+        backend->supportsTensorAlloc()) {
+      // Comm-optimized memory pool is available, use it to allocate the tensor.
+      LOG(INFO)
+          << "Reducer: found comm-optimized memory allocator, using it to create bucket";
+      bucket.gradients = backend->allocateTensor(bucketSize, options);
+    } else {
+      // Plain creation of the tensor.
+      LOG(INFO)
+          << "Reducer: comm-optimized memory allocator not found, using regular one";
+      bucket.gradients = at::empty({bucketSize}, options);
+    }
 
     // Note: "Gradient Layout Contract"
     //
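Usage note (not part of the patch): with this change, DDP's Reducer asks the default NCCL backend to allocate its flattened gradient buckets from the comm-optimized, NCCL-registered memory pool, so no user-side API changes are required. The sketch below is illustrative only; it assumes a standard multi-GPU torchrun launch, and the model/shapes are arbitrary. Setting DDP_DISABLE_COMM_MEM=1 falls back to the regular caching-allocator path.

# Minimal sketch: DDP buckets pick up the comm-optimized NCCL memory pool
# automatically when the default backend is NCCL.
# Opt out with: DDP_DISABLE_COMM_MEM=1 torchrun --nproc-per-node=2 demo.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(1024, 1024).cuda()
    # Reducer::initialize_buckets() now requests the flattened gradient
    # buckets from the NCCL backend's registered MemPool (see reducer.cpp
    # above); the log line "found comm-optimized memory allocator" confirms it.
    ddp_model = DDP(model)

    out = ddp_model(torch.randn(16, 1024, device="cuda"))
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()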