Enables NCCL symmetric memory kernels through mempool registration (#155134)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155134
Approved by: https://github.com/kwen2501

Co-authored-by: Ke Wen <kw2501@meta.com>
This commit is contained in:
Syed Tousif Ahmed
2025-06-21 13:08:13 -07:00
committed by PyTorch MergeBot
parent 9e132b770e
commit f70c80105e
9 changed files with 146 additions and 15 deletions

View File

@ -4205,8 +4205,11 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
bool use_on_oom,
bool symmetric)
: allocator_(allocator),
is_user_created_(is_user_created),
symmetric_(symmetric) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
@ -4229,6 +4232,10 @@ MempoolId_t MemPool::id() {
return id_;
}
bool MemPool::is_symmetric() {
return symmetric_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}

View File

@ -535,7 +535,8 @@ struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false);
bool use_on_oom = false,
bool symmetric = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
@ -543,6 +544,7 @@ struct C10_CUDA_API MemPool {
~MemPool();
MempoolId_t id();
bool is_symmetric();
CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
@ -554,6 +556,7 @@ struct C10_CUDA_API MemPool {
CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
bool symmetric_;
c10::DeviceIndex device_;
};

View File

@ -3084,7 +3084,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
self._run_invalid_nccl_blocking_wait_env("4294967295")
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
class NcclRegistrationTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@ -3095,7 +3095,7 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
if torch.cuda.nccl.version() >= (2, 24, 3):
os.environ["NCCL_DEBUG_SUBSYS"] = "REG"
os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
self._spawn_processes()
@ -3151,6 +3151,48 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
else:
self.assertRegex(nccl_debug_file_content, "local-registered")
@requires_nccl()
@requires_nccl_version((2, 27), "Need NCCL 2.27 for window registration")
@skip_if_lt_x_gpu(4)
@requires_multicast_support()
def test_nccl_window_registration(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
device = torch.device(f"cuda:{self.rank}")
torch.cuda.set_device(self.rank)
pg = c10d.distributed_c10d._get_default_group()
backend = pg._get_backend(torch.device(device))
# Use NCCL memory allocator
# enable symmetric memory usage in NCCL
pool = torch.cuda.MemPool(backend.mem_allocator, symm_mem=True)
# allocate memory with ncclMemAlloc
# note: symmetric kernels are not available for dtypes like torch.int64
with torch.cuda.use_mem_pool(pool):
tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32)
# register buffers to NCCL
backend.register_mem_pool(pool)
# allreduce now should use NVIDIA Switches
pg.allreduce(tensor).wait()
torch.cuda.synchronize(device=device)
# de-register buffers from NCCL
backend.deregister_mem_pool(pool)
# clean up memory
del tensor, pool
with open(os.environ["NCCL_DEBUG_FILE"]) as f:
nccl_debug_file_content = f.read()
# if buffers were registered and symmetric kernels ran, NCCL_DEBUG
# should show successful registration in debug output
self.assertRegex(nccl_debug_file_content, "[Symmetric]")
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
@property

View File

@ -2300,10 +2300,13 @@ class _MemPool:
allocator: _cuda_CUDAAllocator | None = None,
is_user_created: _bool = True,
use_on_oom: _bool = False,
symmetric: _bool = False,
) -> None: ...
@property
def id(self) -> tuple[_int, _int]: ...
@property
def is_symmetric(self) -> _bool: ...
@property
def allocator(self) -> _cuda_CUDAAllocator | None: ...
def use_count(self) -> _int: ...

View File

@ -16,12 +16,15 @@ void THCPMemPool_init(PyObject* module) {
.def(
py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom) {
bool use_on_oom,
bool symmetric) {
torch::utils::device_lazy_init(at::kCUDA);
return std::make_shared<::c10::cuda::MemPool>(
allocator, is_user_created, use_on_oom);
allocator, is_user_created, use_on_oom, symmetric);
}))
.def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly(
"is_symmetric", &::c10::cuda::MemPool::is_symmetric)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
.def("use_count", &::c10::cuda::MemPool::use_count);
}

View File

@ -350,7 +350,8 @@ ncclResult_t NCCLComm::checkForNcclError() {
ncclResult_t NCCLComm::registerSegment(
void* ptr,
size_t size,
bool errorOnRereg /*=true*/) {
bool errorOnRereg, /*=true*/
bool window /*=false*/) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
// We register only segments from cache allocator
@ -371,6 +372,30 @@ ncclResult_t NCCLComm::registerSegment(
void* handle = nullptr;
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
#ifdef NCCL_HAS_COMM_WINDOW_REGISTER
if (window) {
C10D_NCCL_CHECK(
ncclCommWindowRegister(
comm, ptr, size, (ncclWindow_t*)&handle, NCCL_WIN_COLL_SYMMETRIC),
c10::str(
"Failed to window register segment with ptr ",
ptr,
", size ",
size,
" on ncclComm_ ",
comm));
} else {
C10D_NCCL_CHECK(
ncclCommRegister(comm, ptr, size, &handle),
c10::str(
"Failed to register segment with ptr ",
ptr,
", size ",
size,
" on ncclComm_ ",
comm));
}
#else
C10D_NCCL_CHECK(
ncclCommRegister(comm, ptr, size, &handle),
c10::str(
@ -380,6 +405,7 @@ ncclResult_t NCCLComm::registerSegment(
size,
" on ncclComm_ ",
comm));
#endif
registeredSegmentHandles_[ptr] = handle;
return ncclSuccess;
#else
@ -387,7 +413,7 @@ ncclResult_t NCCLComm::registerSegment(
#endif
}
ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
ncclResult_t NCCLComm::deregisterSegment(void* ptr, bool window /*false*/) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
TORCH_CHECK(
@ -400,6 +426,29 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
void* handle = registeredSegmentHandles_[ptr];
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
#ifdef NCCL_HAS_COMM_WINDOW_REGISTER
if (window) {
C10D_NCCL_CHECK(
ncclCommWindowDeregister(comm, (ncclWindow_t)handle),
c10::str(
"Failed to window deregister segment handle ",
handle,
", with ptr ",
ptr,
" on ncclComm_ ",
comm));
} else {
C10D_NCCL_CHECK(
ncclCommDeregister(comm, handle),
c10::str(
"Failed to deregister segment handle ",
handle,
", with ptr ",
ptr,
" on ncclComm_ ",
comm));
}
#else
C10D_NCCL_CHECK(
ncclCommDeregister(comm, handle),
c10::str(
@ -409,6 +458,7 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
ptr,
" on ncclComm_ ",
comm));
#endif
registeredSegmentHandles_.erase(ptr);
return ncclSuccess;
#else

View File

@ -63,6 +63,10 @@ static_assert(
#define NCCL_HAS_COMM_REGISTER
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_WINDOW_REGISTER
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
#define NCCL_HAS_MEM_ALLOC
#endif
@ -341,9 +345,10 @@ class NCCLComm {
ncclResult_t registerSegment(
void* ptr,
size_t size,
bool errorOnRereg = true);
bool errorOnRereg = true,
bool window = false);
ncclResult_t deregisterSegment(void* ptr);
ncclResult_t deregisterSegment(void* ptr, bool window = false);
std::string repr() const;

View File

@ -1182,6 +1182,14 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
// register future segments allocated in this pool (this call is idempotent).
attachAllocatorHooks();
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
// TODO:
// if(pool->is_symmetric()) {
// Allgather to verify len(mempool.snapshot.segments) matches across GPUs
// Allgather to verify mempool.alloc_request_counter matches across GPUs
// add alloc_request_counter per mempool (How many allocations a mempool has
// served during its lifetime) this should guarantee pool is used in a
// symmetric/SPMD manner
// }
for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(),
@ -1190,7 +1198,9 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
reinterpret_cast<void*>(segmentInfo.address),
segmentInfo.total_size,
/*errorOnRereg=*/false); // ignores reregistration error
/*errorOnRereg=*/false, // ignores reregistration error
/*window=*/pool->is_symmetric()); // whether to use NCCL symmetric
// memory
}
}
@ -1221,7 +1231,8 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
segmentInfo.device == pool->device(),
"Mismatch between CUDA memory segment device and pool's device");
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->deregisterSegment(reinterpret_cast<void*>(segmentInfo.address));
ncclComm->deregisterSegment(
reinterpret_cast<void*>(segmentInfo.address), pool->is_symmetric());
}
}

View File

@ -1163,21 +1163,28 @@ class MemPool(_MemPool):
use_on_oom(bool): a bool that indicates if this pool can be used
as a last resort if a memory allocation outside of the pool fails due
to Out Of Memory. This is False by default.
symmetric(bool): a bool that indicates if this pool is symmetrical
across ranks. This is False by default.
"""
def __init__(
self,
allocator: Optional[_cuda_CUDAAllocator] = None,
use_on_oom: bool = False,
symmetric: bool = False,
):
super().__init__(allocator, True, use_on_oom)
super().__init__(allocator, True, use_on_oom, symmetric)
@property
def id(self) -> tuple[int, int]:
r"""Returns the ID of this pool as a tuple of two ints."""
return super().id
@property
def is_symmetric(self) -> bool:
r"""Returns whether this pool is used for NCCL's symmetric memory."""
return super().is_symmetric
@property
def allocator(self) -> Optional[_cuda_CUDAAllocator]:
r"""Returns the allocator this MemPool routes allocations to."""