Add option to use mempool on OOM (#151487)

MemPool is a separate pool of memory handled by the caching allocator. This PR adds the option to let the caching allocator use this pool as a last resort instead of OOMing, by associating a ``use_on_oom`` bool with each MemPool.

Usage:
Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as a last resort instead of OOMing.

```
# `allocator` is optional here, e.g. a torch.cuda.memory.CUDAPluggableAllocator handle
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a

# At the memory limit, this allocation succeeds by falling back to the pool's
# memory instead of raising an OOM.
b = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```
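
For the fallback to kick in, the process has to already be at its memory limit. A minimal sketch of one way to arrange that, mirroring the `_setup_mempool_limited_memory_test` helper in the test below (the 80 MB of headroom is illustrative):

```
import torch

device = torch.device("cuda:0")
mb = 1024 * 1024

# Cap this process to ~80 MB on top of what the caching allocator already reserves,
# so the allocations above quickly run into the limit.
_, total_memory = torch.cuda.mem_get_info(device)
pre_reserved = torch.cuda.memory_reserved(device)
torch.cuda.set_per_process_memory_fraction((80 * mb + pre_reserved) / total_memory, device)
```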

Testing:
```
python test/test_cuda.py -k test_mempool_limited_memory_with_allocator
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151487
Approved by: https://github.com/eqy, https://github.com/syed-ahmed, https://github.com/ngimel
Author: Dan Johnson
Date: 2025-04-25 17:52:27 -07:00
Committed by: PyTorch MergeBot
Parent: 65b845f82b
Commit: d22c4cc353
7 changed files with 217 additions and 7 deletions


@@ -4920,6 +4920,25 @@ class TestBlockStateAbsorption(TestCase):

@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestMemPool(TestCase):
    def _setup_mempool_limited_memory_test(self, additional_allowed_memory_in_mb):
        device = torch.device("cuda:0")
        self.init_fraction = torch.cuda.get_per_process_memory_fraction()
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info(device)
        pre_reserved = torch.cuda.memory_reserved(device)
        total_allowed = additional_allowed_memory_in_mb * mb + pre_reserved
        fraction_allowed = total_allowed / all_memory
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed, device)
        dtype = torch.int8
        return device, dtype

    def _teardown_mempool_limited_memory_test(self):
        torch.cuda.memory.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(self.init_fraction)

    def test_mempool_id(self):
        pool1 = torch.cuda.graph_pool_handle()
        pool2 = torch.cuda.MemPool().id
@@ -5036,6 +5055,110 @@ class TestMemPool(TestCase):
        # out tensor
        self.assertEqual(called_dummy_free.value, 321)

    @serialTest()
    def test_mempool_limited_memory_with_allocator(self):
        from torch.utils.cpp_extension import load_inline

        dummy_allocator_source = """
        #include <torch/extension.h>
        #include <ATen/cuda/Exceptions.h>
        #include <cuda_runtime_api.h>

        extern "C" {
          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
            void* ptr;
            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
            return ptr;
          }

          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
            C10_CUDA_CHECK(cudaFree(ptr));
          }
        }
        """
        dummy_allocator_libname = "dummy_allocator"
        dummy_allocator = load_inline(
            name=dummy_allocator_libname,
            cpp_sources=dummy_allocator_source,
            is_python_module=False,
            keep_intermediates=False,
            verbose=True,
            with_cuda=True,
        )
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            dummy_allocator,
            "dummy_alloc",
            "dummy_free",
        )
        pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
        pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)

        nelem_1mb = 1024 * 1024 // 4

        self._setup_mempool_limited_memory_test(80)
        # remaining free mem: 80 mb
        # mempool_use          []  0 mb
        # mempool_do_not_use   []  0 mb
        # default pool         []  0 mb
        with torch.cuda.use_mem_pool(pool_do_not_use):
            a = torch.randn(40 * nelem_1mb, device="cuda")
        with torch.cuda.use_mem_pool(pool_use):
            b = torch.randn(40 * nelem_1mb, device="cuda")
        a_dataptr = a.data_ptr()
        b_dataptr = b.data_ptr()
        # remaining free mem: 0 mb
        # mempool_do_not_use   [aaaa] 40 mb
        # mempool_use          [bbbb] 40 mb
        # default pool         []      0 mb

        with self.assertRaises(torch.OutOfMemoryError):
            # out of memory
            c = torch.randn(40 * nelem_1mb, device="cuda")

        del a, b
        # remaining free mem: 0 mb
        # mempool_do_not_use   [____] 40 mb
        # mempool_use          [____] 40 mb
        # default pool         []      0 mb

        # c should not oom and instead can use mempool_use as fallback
        c = torch.randn(30 * nelem_1mb, device="cuda")
        c_dataptr = c.data_ptr()
        # remaining free mem: 0 mb
        # mempool_do_not_use   [____] 40 mb
        # mempool_use          [ccc_] 40 mb
        # default pool         []      0 mb

        with self.assertRaises(torch.OutOfMemoryError):
            # out of memory since can't use mempool_do_not_use
            d = torch.randn(30 * nelem_1mb, device="cuda")

        del c
        # remaining free mem: 0 mb
        # mempool_do_not_use   [____] 40 mb
        # mempool_use          [____] 40 mb
        # default pool         []      0 mb

        # expect that we used the same memory address for both b and c
        self.assertEqual(b_dataptr, c_dataptr)

        # make sure we can still use mempool_use as intended after c is deleted
        with torch.cuda.use_mem_pool(pool_use):
            e = torch.randn(20 * nelem_1mb, device="cuda")
        # remaining free mem: 0 mb
        # mempool_do_not_use   [____] 40 mb
        # mempool_use          [ee__] 40 mb
        # default pool         []      0 mb
        e_dataptr = e.data_ptr()
        del e

        self.assertEqual(e_dataptr, c_dataptr)

        # pool's destructor calls emptyCache()
        del pool_use, pool_do_not_use

        self._teardown_mempool_limited_memory_test()

    def test_mempool_context(self):
        active_pool = torch.cuda.MemPoolContext.active_pool()