From be53f609aaf6f01e2863f490975ea9eaac3ee9ff Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 12 Aug 2025 02:03:15 +0000 Subject: [PATCH] fix retaining multimem in symmetric memory (#160343) fixes OOM in #160289 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343 Approved by: https://github.com/eqy --- c10/cuda/driver_api.h | 3 ++- .../c10d/symm_mem/CUDASymmetricMemory.cu | 14 ++++++++++++-- .../c10d/symm_mem/CUDASymmetricMemory.hpp | 4 +++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 9800809d1e53..6702cb9b532d 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -53,7 +53,8 @@ #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ - _(cuMulticastCreate, 12030) + _(cuMulticastCreate, 12030) \ + _(cuMulticastUnbind, 12030) #else #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index e9fc7aefaf57..b2f216335bb1 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -46,11 +46,13 @@ AllocationRef::AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx) + int device_idx, + bool is_multicast) : ptr(ptr), handle(handle), block_size(block_size), - device_idx(device_idx) {} + device_idx(device_idx), + is_multicast(is_multicast) {} AllocationRef::~AllocationRef() { if (is_finalizing()) { @@ -63,6 +65,10 @@ AllocationRef::~AllocationRef() { auto driver_api = c10::cuda::DriverAPI::get(); C10_CUDA_DRIVER_CHECK( driver_api->cuMemUnmap_(reinterpret_cast(ptr), block_size)); + if (is_multicast) { + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size)); + } C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle)); #elif defined(USE_ROCM) C10_HIP_CHECK(hipMemUnmap(reinterpret_cast(ptr), block_size)); @@ -797,6 +803,10 @@ c10::intrusive_ptr make_symm_mem( for (int r = 0; r < world_size; ++r) { if (r == rank) { alloc_refs.emplace_back(block->alloc_ref); + if (mc_addr != nullptr) { + alloc_refs.push_back(c10::make_intrusive( + mc_addr, mc_handle, block->block_size, block->device_idx, true)); + } continue; } alloc_refs.push_back(c10::make_intrusive( diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index a5340ffc9806..f61d8f9622a7 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target { HandleType handle; size_t block_size; int device_idx; + bool is_multicast; AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx); + int device_idx, + bool is_multicast = false); ~AllocationRef(); };