[DDP] Use NCCL allocated memory for gradient bucket (#146589)
Allocate DDP gradient buckets from NCCL-allocated memory so that NVLink SHARP runs with zero-copy on H100+ platforms for DDP applications. This means less SM usage and less memory contention between NCCL kernels and compute kernels.

Added env `DDP_DISABLE_COMM_MEM` as a back-out option:

```
An environment variable to disable comm-optimized memory pool.
Default is 0, which means comm-optimized memory pool is enabled.
Users can set it to 1 in case of seeing regression or OOM (because this
comm MemPool may not share space with regular compute MemPool).
```

Differential Revision: [D69297766](https://our.internmc.facebook.com/intern/diff/D69297766)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146589
Approved by: https://github.com/syed-ahmed, https://github.com/c-p-i-o, https://github.com/fduwjj
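For context, a minimal sketch of how a user could exercise the back-out option from a DDP script. Only the `DDP_DISABLE_COMM_MEM` variable comes from this PR; the process-group setup and the model are illustrative placeholders.

```python
# Sketch only: back out of the comm-optimized bucket memory pool.
# The variable is read when the Reducer initializes its buckets, so it must be
# set before the model is wrapped in DistributedDataParallel.
import os

os.environ["DDP_DISABLE_COMM_MEM"] = "1"  # default "0" keeps the comm pool enabled

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")  # assumes a torchrun-style env:// launch
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])
```

With the variable left at its default, the Reducer asks the default backend for a comm-optimized allocation (`backend->allocateTensor(...)` in the reducer.cpp hunk below); with it set to 1, buckets fall back to plain `at::empty` allocation.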
@@ -2007,7 +2007,7 @@ class DistributedDataParallelTest(
         replica_devices = [dev0]
         # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
         layer_devs = dev0
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)

     @requires_nccl()
@@ -2021,7 +2021,7 @@ class DistributedDataParallelTest(
         replica_devices = None
         # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device.
         layer_devs = [dev0] * 2 + [dev1] * 2
-        local_batch_size = 8
+        local_batch_size = 16
         self._test_grad_layout(replica_devices, layer_devs, local_batch_size)

     @requires_nccl()
@@ -417,6 +417,21 @@ class TORCH_API Backend : public torch::CustomClassHolder {
             "Backend ", getBackendName(), " does not support getMemAllocator"));
   }

+  // Allocate tensor (aten::empty) from backend's communication-optimized memory
+  // pool
+  virtual at::Tensor allocateTensor(long size, at::TensorOptions options = {}) {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ", getBackendName(), " does not support allocateTensor"));
+  }
+
+  // Returns true if backend supports tensor allocation
+  virtual bool supportsTensorAlloc() {
+    // Change to true in concrete backend if supported
+    return false;
+  }
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
@@ -340,19 +340,26 @@ ncclResult_t NCCLComm::checkForNcclError() {
 #endif
 }

-ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
+ncclResult_t NCCLComm::registerSegment(
+    void* ptr,
+    size_t size,
+    bool errorOnRereg /*=true*/) {
   LockType lock(mutex_);
 #ifdef NCCL_HAS_COMM_REGISTER
   // We register only segments from cache allocator
   // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
   // maps to a unique handle and should not be registered before the current
   // ptr is deregistered and freed.
-  TORCH_CHECK(
-      registeredSegmentHandles_.count(ptr) == 0,
-      "Segment with ptr ",
-      ptr,
-      " has already been registered on ncclComm_ ",
-      ncclComm_);
+  if (registeredSegmentHandles_.count(ptr) > 0) {
+    TORCH_CHECK(
+        !errorOnRereg,
+        "Segment with ptr ",
+        ptr,
+        " has already been registered on ncclComm_ ",
+        ncclComm_);
+    // Skip below
+    return ncclSuccess;
+  }

   void* handle = nullptr;
   // Use getNcclComm to make sure comm is ready before calling nccl APIs
@@ -284,7 +284,10 @@ class NCCLComm {

   ncclResult_t checkForNcclError();

-  ncclResult_t registerSegment(void* ptr, size_t size);
+  ncclResult_t registerSegment(
+      void* ptr,
+      size_t size,
+      bool errorOnRereg = true);

   ncclResult_t deregisterSegment(void* ptr);

@@ -1175,7 +1175,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
     ncclComm->registerSegment(
         // NOLINTNEXTLINE(performance-no-int-to-ptr)
         reinterpret_cast<void*>(segmentInfo.address),
-        segmentInfo.total_size);
+        segmentInfo.total_size,
+        /*errorOnRereg=*/false); // ignores reregistration error
   }
 }

@@ -1455,6 +1456,14 @@ void ProcessGroupNCCL::shutdown() {
     // Use long interval to avoid acquiring CPU too frequently
     ncclComm->waitReady(true);
   }
+  // Deregister memory pool after finalizing all collectives
+  if (memPool_) {
+    try {
+      deregisterMemPool(memPool_.get());
+    } catch (...) {
+      LOG(ERROR) << logPrefix() << "Failed to deregister memory pool, ignoring";
+    }
+  }
   // Tell watchdog to (1) flush its queue and (2) do not use comm objects
   // anymore because I am going to destroy them now
   LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread.";
@@ -5422,6 +5431,46 @@ std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() {
   return ncclMemAllocator;
 }

+at::Tensor ProcessGroupNCCL::allocateTensor(
+    long size,
+    at::TensorOptions options) {
+  // Some checks
+  TORCH_CHECK_VALUE(options.has_device(), "Tensor options must include device");
+  auto device = options.device();
+  TORCH_CHECK_VALUE(
+      device.is_cuda(),
+      "NCCL tensor allocator expects cuda type but got " + c10::str(device))
+
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+
+  // Create memory pool
+  if (!memPool_) {
+    // Needs a CUDAAllocator
+    auto allocator =
+        reinterpret_cast<c10::cuda::CUDACachingAllocator::CUDAAllocator*>(
+            getMemAllocator().get());
+    // Pool is created
+    memPool_ = std::make_unique<c10::cuda::MemPool>(allocator);
+    LOG(INFO) << logPrefix() << "Created memory pool";
+  }
+
+  // Allocate tensor under this MemPool's context
+  auto ctx = c10::cuda::MemPoolContext(memPool_.get());
+  c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+      memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; });
+  at::Tensor tensor = at::empty({size}, options);
+  // Also need to ncclCommRegister the pool in case new segments are created;
+  // reregistration of old segments will be ignored
+  registerMemPool(memPool_.get());
+  c10::cuda::CUDACachingAllocator::endAllocateToPool(
+      memPool_->device(), memPool_->id());
+  c10::cuda::CUDACachingAllocator::releasePool(
+      memPool_->device(), memPool_->id());
+  LOG(INFO) << logPrefix() << "Allocated tensor of size " << size
+            << " from memory pool";
+  return tensor;
+}
+
 } // namespace c10d

 #endif // USE_C10D_NCCL
@@ -774,6 +774,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   std::shared_ptr<c10::Allocator> getMemAllocator() override;

+  // Allocate tensor from communication-optimized memory pool
+  at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override;
+
+  bool supportsTensorAlloc() override {
+    return true;
+  }
+
   // Performs NCCL user buffer registration for all buffers in
   // the given MemPool
   void registerMemPool(c10::cuda::MemPool* pool);
@@ -1294,6 +1301,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Internal cached value: use NCCL non-blocking API mode or not.
   // Use `useNonblocking()` method instead of accessing this variable directly.
   std::optional<bool> useNonblocking_{std::nullopt};
+
+  // Communication-optimized memory pool associated with this PG
+  std::unique_ptr<c10::cuda::MemPool> memPool_ = nullptr;
 };

 // Dumps the NCCL comm traces and additional information about the Process
@@ -1157,14 +1157,44 @@ void Reducer::initialize_buckets(
       offset += length;
     }

-    // Allocate the bucket's flattened `gradients` tensor.
     // Make gradient type in the reduced precision if mixed precision is
     // enabled. This ensures that the type is correct when e.g. rebuilding
     // buckets.
     if (mixed_precision_param_dtype_.has_value()) {
       options = options.dtype(mixed_precision_param_dtype_);
     }
-    bucket.gradients = at::empty({static_cast<long>(offset)}, options);
+
+    // Allocate the bucket's flattened `gradients` tensor.
+    auto bucketSize = static_cast<long>(offset);
+    // Check if we can use comm-optimized memory pool to allocate tensor
+    c10::intrusive_ptr<Backend> backend = nullptr;
+    // An environment variable to disable comm-optimized memory pool.
+    // Default is 0, which means comm-optimized memory pool is enabled.
+    // Users can set it to 1 in case of seeing regression or OOM (because this
+    // comm MemPool may not share space with regular compute MemPool).
+    bool ddpDisableCommMem =
+        (getCvarString({"DDP_DISABLE_COMM_MEM"}, "0") == "1");
+    try {
+      backend = process_group_->getDefaultBackend();
+    } catch (...) {
+      // Sometimes the backend type can be `UNDEFINED` rather than `NCCL` or
+      // `GLOO`. In this case, we just fall back to the regular way of
+      // creating tensor
+      LOG(INFO)
+          << "Reducer: default comm backend not found, skipping bucket memory optimization";
+    }
+    if (ddpDisableCommMem == 0 && backend != nullptr &&
+        backend->supportsTensorAlloc()) {
+      // Comm-optimized memory pool is available, use it to allocate tensor
+      LOG(INFO)
+          << "Reducer: found comm-optimized memory allocator, using it to create bucket";
+      bucket.gradients = backend->allocateTensor(bucketSize, options);
+    } else {
+      // Plain creation of tensor
+      LOG(INFO)
+          << "Reducer: comm-optimized memory allocator not found, using regular one";
+      bucket.gradients = at::empty({bucketSize}, options);
+    }

     // Note: "Gradient Layout Contract"
     //