[c10d] Add NCCL memory allocator (#145675)

This PR implements a small UI improvement over #133603.

It prepares an NCCL memory allocator in PyTorch's C++ layer and exposes it to Python via pybind, so that users can use it directly.

UI:
```
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```
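
In a typical setup, the backend is obtained from the default process group. Below is a minimal end-to-end sketch; the process-group setup shown is illustrative and assumed, since this PR only adds `backend.mem_allocator`:
```
import torch
import torch.distributed as dist

# Illustrative setup (assumed); this PR only adds `backend.mem_allocator`.
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{dist.get_rank()}")
torch.cuda.set_device(device)

pg = dist.distributed_c10d._get_default_group()
backend = pg._get_backend(device)

# Allocations made inside the pool go through ncclMemAlloc/ncclMemFree.
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```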

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145675
Approved by: https://github.com/syed-ahmed, https://github.com/wconstab
Authored by Ke Wen on 2025-01-30 10:10:58 -08:00, committed by PyTorch MergeBot.
Parent 7796e308d0, commit 51ee9b154e.
7 changed files with 68 additions and 43 deletions.

test/distributed/test_c10d_nccl.py

@@ -67,7 +67,6 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
TestCase,
)
from torch.utils.cpp_extension import load_inline
if TEST_WITH_DEV_DBG_ASAN:
@@ -3104,40 +3103,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def createNcclAllocator(self):
nccl_allocator_source = """
#include <torch/extension.h>
#include <nccl.h>
#include <iostream>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* nccl_alloc(size_t size, int device, void* stream) {
std::cout << "Using ncclMemAlloc" << std::endl;
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
C10_EXPORT void nccl_free(void* ptr, size_t size, int device, void* stream) {
std::cout << "Using ncclMemFree" << std::endl;
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
nccl_allocator_libname = "nccl_allocator"
nccl_allocator = load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
is_python_module=False,
keep_intermediates=False,
verbose=True,
)
return nccl_allocator
def setUp(self):
super().setUp()
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3172,13 +3137,9 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
torch.cuda.set_device(self.rank)
pg = c10d.distributed_c10d._get_default_group()
backend = pg._get_backend(torch.device(device))
allocator_path = self.createNcclAllocator()
allocator = torch.cuda.memory.CUDAPluggableAllocator(
allocator_path,
"nccl_alloc",
"nccl_free",
)
pool = torch.cuda.MemPool(allocator.allocator())
# Use NCCL memory allocator
pool = torch.cuda.MemPool(backend.mem_allocator)
# allocate memory with ncclMemAlloc
with torch.cuda.use_mem_pool(pool):

torch/_C/_distributed_c10d.pyi

@@ -296,6 +296,8 @@ class Backend:
def _set_sequence_number_for_group(self) -> None: ...
def _set_default_timeout(self, timeout: timedelta) -> None: ...
def get_error(self) -> ErrorType: ...
@property
def mem_allocator(self) -> Any: ...
class ProcessGroup:
class BackendType(Enum):

torch/csrc/distributed/c10d/Backend.hpp

@@ -5,6 +5,7 @@
#include <vector>
#include <ATen/ATen.h>
#include <c10/core/Allocator.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/distributed/c10d/Types.hpp>
@@ -409,6 +410,13 @@ class TORCH_API Backend : public torch::CustomClassHolder {
c10::str("Backend ", getBackendName(), " does not support getError"));
}
virtual std::shared_ptr<c10::Allocator> getMemAllocator() {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), " does not support getMemAllocator"));
}
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.

torch/csrc/distributed/c10d/NCCLUtils.hpp

@@ -86,6 +86,10 @@ static_assert(
#define NCCL_HAS_COMM_REGISTER
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
#define NCCL_HAS_MEM_ALLOC
#endif
// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

@@ -20,6 +20,7 @@
#include <c10/util/WaitCounter.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -5249,6 +5250,51 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
avoidRecordStreams);
}
// Create a memory allocator for NCCL. This allocator is used to allocate
// memory that supports NVLink SHARP functionality. The allocator is later
// exposed to Python via pybind so that users can use it to create a MemPool.
// For example:
// >>> pool = torch.cuda.MemPool(backend.mem_allocator)
// Allocate function
static void* _ncclMemAlloc(size_t size, int device, void* stream) {
#ifndef NCCL_HAS_MEM_ALLOC
TORCH_CHECK(
false, "NCCL mem allocator is not supported in this NCCL version");
#else
LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes";
at::cuda::OptionalCUDAGuard gpuGuard(device);
void* ptr = nullptr;
TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed");
return ptr;
#endif // NCCL_HAS_MEM_ALLOC
}
// Free function
static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) {
#ifndef NCCL_HAS_MEM_ALLOC
TORCH_CHECK(
false, "NCCL mem allocator is not supported in this NCCL version");
#else
LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes";
at::cuda::OptionalCUDAGuard gpuGuard(device);
TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed");
#endif // NCCL_HAS_MEM_ALLOC
}
// Create a `CUDAPluggableAllocator` that uses the above functions.
std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() {
#ifndef NCCL_HAS_MEM_ALLOC
TORCH_CHECK(
false, "NCCL mem allocator is not supported in this NCCL version");
#endif // NCCL_HAS_MEM_ALLOC
C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.getMemAllocator");
static std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
ncclMemAllocator =
torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
_ncclMemAlloc, _ncclMemFree);
return ncclMemAllocator;
}
} // namespace c10d
#endif // USE_C10D_NCCL
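
Because the implementation is gated on `NCCL_HAS_MEM_ALLOC` (NCCL >= 2.19), accessing `backend.mem_allocator` on older NCCL builds hits the TORCH_CHECK failure above. A sketch of guarding on the runtime NCCL version; the fallback to a default MemPool is an assumption, not part of this PR:
```
import torch

# ncclMemAlloc/ncclMemFree require NCCL >= 2.19 (see NCCL_HAS_MEM_ALLOC);
# on older builds, accessing backend.mem_allocator raises.
if torch.cuda.nccl.version() >= (2, 19):
    pool = torch.cuda.MemPool(backend.mem_allocator)
else:
    pool = torch.cuda.MemPool()  # assumed fallback: default caching allocator
```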

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

@@ -768,6 +768,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
ErrorType getError() override;
std::shared_ptr<c10::Allocator> getMemAllocator() override;
// Performs NCCL user buffer registration for all buffers in
// the given MemPool
void registerMemPool(c10::cuda::MemPool* pool);

torch/csrc/distributed/c10d/init.cpp

@@ -2765,7 +2765,9 @@ Arguments:
.def(
"_end_coalescing",
&::c10d::Backend::endCoalescing,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"mem_allocator", &::c10d::Backend::getMemAllocator);
// base Backend::Options binding
// TODO: Maybe we can consider how to merge this with