mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix retaining multimem in symmetric memory (#160343)
Fixes OOM in #160289.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343
Approved by: https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
95210cc409
commit
be53f609aa
@ -53,7 +53,8 @@
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||
_(cuMulticastAddDevice, 12030) \
|
||||
_(cuMulticastBindMem, 12030) \
|
||||
_(cuMulticastCreate, 12030)
|
||||
_(cuMulticastCreate, 12030) \
|
||||
_(cuMulticastUnbind, 12030)
|
||||
#else
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_)
|
||||
#endif
|
||||
|
@ -46,11 +46,13 @@ AllocationRef::AllocationRef(
|
||||
void* ptr,
|
||||
HandleType handle,
|
||||
size_t block_size,
|
||||
int device_idx)
|
||||
int device_idx,
|
||||
bool is_multicast)
|
||||
: ptr(ptr),
|
||||
handle(handle),
|
||||
block_size(block_size),
|
||||
device_idx(device_idx) {}
|
||||
device_idx(device_idx),
|
||||
is_multicast(is_multicast) {}
|
||||
|
||||
AllocationRef::~AllocationRef() {
|
||||
if (is_finalizing()) {
|
||||
@ -63,6 +65,10 @@ AllocationRef::~AllocationRef() {
|
||||
auto driver_api = c10::cuda::DriverAPI::get();
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
|
||||
if (is_multicast) {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size));
|
||||
}
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
|
||||
#elif defined(USE_ROCM)
|
||||
C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
|
||||
@ -797,6 +803,10 @@ c10::intrusive_ptr<CUDASymmetricMemory> make_symm_mem(
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) {
|
||||
alloc_refs.emplace_back(block->alloc_ref);
|
||||
if (mc_addr != nullptr) {
|
||||
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
|
||||
mc_addr, mc_handle, block->block_size, block->device_idx, true));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
|
||||
|
@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target {
|
||||
HandleType handle;
|
||||
size_t block_size;
|
||||
int device_idx;
|
||||
bool is_multicast;
|
||||
|
||||
AllocationRef(
|
||||
void* ptr,
|
||||
HandleType handle,
|
||||
size_t block_size,
|
||||
int device_idx);
|
||||
int device_idx,
|
||||
bool is_multicast = false);
|
||||
|
||||
~AllocationRef();
|
||||
};
|
||||
|
Reference in New Issue
Block a user