fix retaining multimem in symmetric memory (#160343)

fixes OOM in #160289

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343
Approved by: https://github.com/eqy
This commit is contained in:
Natalia Gimelshein
2025-08-12 02:03:15 +00:00
committed by PyTorch MergeBot
parent 95210cc409
commit be53f609aa
3 changed files with 17 additions and 4 deletions

View File

@ -53,7 +53,8 @@
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030)
_(cuMulticastCreate, 12030) \
_(cuMulticastUnbind, 12030)
#else
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_)
#endif

View File

@ -46,11 +46,13 @@ AllocationRef::AllocationRef(
void* ptr,
HandleType handle,
size_t block_size,
int device_idx)
int device_idx,
bool is_multicast)
: ptr(ptr),
handle(handle),
block_size(block_size),
device_idx(device_idx) {}
device_idx(device_idx),
is_multicast(is_multicast) {}
AllocationRef::~AllocationRef() {
if (is_finalizing()) {
@ -63,6 +65,10 @@ AllocationRef::~AllocationRef() {
auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
if (is_multicast) {
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size));
}
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
@ -797,6 +803,10 @@ c10::intrusive_ptr<CUDASymmetricMemory> make_symm_mem(
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
alloc_refs.emplace_back(block->alloc_ref);
if (mc_addr != nullptr) {
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
mc_addr, mc_handle, block->block_size, block->device_idx, true));
}
continue;
}
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(

View File

@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target {
HandleType handle;
size_t block_size;
int device_idx;
bool is_multicast;
AllocationRef(
void* ptr,
HandleType handle,
size_t block_size,
int device_idx);
int device_idx,
bool is_multicast = false);
~AllocationRef();
};