Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[c10d] Add NCCL memory allocator (#145675)
This PR implements a small UI improvement over #133603. It prepares an NCCL memory allocator in torch cpp and then pybinds it out, so that users can use it directly. UI:

```python
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145675
Approved by: https://github.com/syed-ahmed, https://github.com/wconstab
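For orientation, a slightly fuller sketch of the intended flow. The process-group bootstrap mirrors the updated test below; the allreduce and synchronize at the end are illustrative additions, and only `backend.mem_allocator` is new in this PR:

```python
import torch
import torch.distributed as dist

# Sketch: assumes a torchrun-style launch with one GPU per rank and NCCL available.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")

pg = dist.distributed_c10d._get_default_group()
backend = pg._get_backend(device)

# Pool whose memory comes from ncclMemAlloc/ncclMemFree
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)

dist.all_reduce(tensor)  # illustrative collective on the pool-backed tensor
torch.cuda.synchronize()
```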
@@ -67,7 +67,6 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM,
     TestCase,
 )
-from torch.utils.cpp_extension import load_inline


 if TEST_WITH_DEV_DBG_ASAN:
@@ -3104,40 +3103,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):


 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
-    def createNcclAllocator(self):
-        nccl_allocator_source = """
-        #include <torch/extension.h>
-        #include <nccl.h>
-        #include <iostream>
-
-        extern "C" {
-
-        // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-        C10_EXPORT void* nccl_alloc(size_t size, int device, void* stream) {
-          std::cout << "Using ncclMemAlloc" << std::endl;
-          void* ptr;
-          ncclResult_t err = ncclMemAlloc(&ptr, size);
-          return ptr;
-        }
-
-        C10_EXPORT void nccl_free(void* ptr, size_t size, int device, void* stream) {
-          std::cout << "Using ncclMemFree" << std::endl;
-          ncclResult_t err = ncclMemFree(ptr);
-        }
-        }
-        """
-        nccl_allocator_libname = "nccl_allocator"
-        nccl_allocator = load_inline(
-            name=nccl_allocator_libname,
-            cpp_sources=nccl_allocator_source,
-            with_cuda=True,
-            extra_ldflags=["-lnccl"],
-            is_python_module=False,
-            keep_intermediates=False,
-            verbose=True,
-        )
-        return nccl_allocator
-
     def setUp(self):
         super().setUp()
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3172,13 +3137,9 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
         torch.cuda.set_device(self.rank)
         pg = c10d.distributed_c10d._get_default_group()
         backend = pg._get_backend(torch.device(device))
-        allocator_path = self.createNcclAllocator()
-        allocator = torch.cuda.memory.CUDAPluggableAllocator(
-            allocator_path,
-            "nccl_alloc",
-            "nccl_free",
-        )
-        pool = torch.cuda.MemPool(allocator.allocator())
+
+        # Use NCCL memory allocator
+        pool = torch.cuda.MemPool(backend.mem_allocator)

         # allocate memory with ncclMemAlloc
         with torch.cuda.use_mem_pool(pool):
@@ -296,6 +296,8 @@ class Backend:
     def _set_sequence_number_for_group(self) -> None: ...
     def _set_default_timeout(self, timeout: timedelta) -> None: ...
     def get_error(self) -> ErrorType: ...
+    @property
+    def mem_allocator(self) -> Any: ...

 class ProcessGroup:
     class BackendType(Enum):
@@ -5,6 +5,7 @@
 #include <vector>

 #include <ATen/ATen.h>
+#include <c10/core/Allocator.h>
 #include <c10/macros/Macros.h>

 #include <torch/csrc/distributed/c10d/Types.hpp>
@@ -409,6 +410,13 @@ class TORCH_API Backend : public torch::CustomClassHolder {
         c10::str("Backend ", getBackendName(), " does not support getError"));
   }

+  virtual std::shared_ptr<c10::Allocator> getMemAllocator() {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ", getBackendName(), " does not support getMemAllocator"));
+  }
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
@@ -86,6 +86,10 @@ static_assert(
 #define NCCL_HAS_COMM_REGISTER
 #endif

+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
+#define NCCL_HAS_MEM_ALLOC
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do {                                      \
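The gate above means `ncclMemAlloc`/`ncclMemFree` are only compiled in for NCCL >= 2.19. A minimal sketch of the matching runtime check on the Python side, using `torch.cuda.nccl.version()` (which reports the linked NCCL version as a tuple); `backend` is assumed to be a ProcessGroupNCCL backend obtained via `pg._get_backend(...)` as in the test above:

```python
import torch

# ncclMemAlloc/ncclMemFree require NCCL >= 2.19 (see the version gate above)
if torch.cuda.nccl.version() >= (2, 19):
    pool = torch.cuda.MemPool(backend.mem_allocator)
else:
    pool = None  # NCCL too old for ncclMemAlloc; use the default allocator
```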
@@ -20,6 +20,7 @@
 #include <c10/util/WaitCounter.h>
 #include <c10/util/irange.h>
 #include <c10/util/thread_name.h>
+#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
 #include <torch/csrc/cuda/nccl.h>
 #include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -5249,6 +5250,51 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
       avoidRecordStreams);
 }

+// Create a memory allocator for NCCL. This allocator is used to allocate memory
+// that supports NVLink Sharp functionality. This allocator is later pybinded to
+// python, so that users can use it to create MemPool. For example:
+// >>> pool = torch.cuda.MemPool(backend.mem_allocator)
+
+// Allocate function
+static void* _ncclMemAlloc(size_t size, int device, void* stream) {
+#ifndef NCCL_HAS_MEM_ALLOC
+  TORCH_CHECK(
+      false, "NCCL mem allocator is not supported in this NCCL version");
+#else
+  LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes";
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+  void* ptr = nullptr;
+  TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed");
+  return ptr;
+#endif // NCCL_HAS_MEM_ALLOC
+}
+
+// Free function
+static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) {
+#ifndef NCCL_HAS_MEM_ALLOC
+  TORCH_CHECK(
+      false, "NCCL mem allocator is not supported in this NCCL version");
+#else
+  LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes";
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+  TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed");
+#endif // NCCL_HAS_MEM_ALLOC
+}
+
+// Create a `CUDAPluggableAllocator` that uses the above functions.
+std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() {
+#ifndef NCCL_HAS_MEM_ALLOC
+  TORCH_CHECK(
+      false, "NCCL mem allocator is not supported in this NCCL version");
+#endif // NCCL_HAS_MEM_ALLOC
+  C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.getMemAllocator");
+  static std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
+      ncclMemAllocator =
+          torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
+              _ncclMemAlloc, _ncclMemFree);
+  return ncclMemAllocator;
+}
+
 } // namespace c10d

 #endif // USE_C10D_NCCL
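Since a failed `TORCH_CHECK` surfaces in Python as a `RuntimeError`, callers can also probe support at runtime instead of checking versions. A hedged sketch (the no-argument `MemPool()` fallback uses the default allocator):

```python
# Sketch: both the NCCL version gate and a backend that does not override
# getMemAllocator raise via TORCH_CHECK, which reaches Python as RuntimeError.
try:
    pool = torch.cuda.MemPool(backend.mem_allocator)
except RuntimeError:
    pool = torch.cuda.MemPool()  # fall back to the default caching allocator
```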
@@ -768,6 +768,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   ErrorType getError() override;

+  std::shared_ptr<c10::Allocator> getMemAllocator() override;
+
   // Performs NCCL user buffer registration for all buffers in
   // the given MemPool
   void registerMemPool(c10::cuda::MemPool* pool);
@@ -2765,7 +2765,9 @@ Arguments:
       .def(
           "_end_coalescing",
           &::c10d::Backend::endCoalescing,
-          py::call_guard<py::gil_scoped_release>());
+          py::call_guard<py::gil_scoped_release>())
+      .def_property_readonly(
+          "mem_allocator", &::c10d::Backend::getMemAllocator);

   // base Backend::Options binding
   // TODO: Maybe we can consider how to merge this with
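Because the binding uses `def_property_readonly`, `getMemAllocator` appears on the Python `Backend` object as an attribute rather than a method, matching the `@property` stub added above. For example:

```python
backend = pg._get_backend(torch.device("cuda"))
alloc = backend.mem_allocator  # read-only property; no parentheses
pool = torch.cuda.MemPool(alloc)
```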