Add option to use mempool on OOM (#151487)
MemPool is a separate pool of memory handled by the caching allocator. This PR adds the option to let the caching allocator try to use this pool as a last resort instead of OOMing, by associating a ``use_on_oom`` bool with each MemPool.
Usage:
Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool creation. If true, the CUDACachingAllocator will be able to use memory in this pool as a last resort instead of OOMing.
```python
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a
# at the memory limit, this will succeed by using pool's memory in order to avoid the oom
b = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```
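Here ``allocator`` is a pluggable-allocator handle. For context, the test below builds one with `torch.cuda.memory.CUDAPluggableAllocator`; a minimal sketch of the same construction (the `dummy_allocator.so` path and the exported function names are placeholders):

```python
# Sketch: build a pluggable allocator from a compiled library that exports
# alloc/free functions, then hand its raw handle to a MemPool.
# "dummy_allocator.so", "dummy_alloc", and "dummy_free" are placeholders.
pluggable = torch.cuda.memory.CUDAPluggableAllocator(
    "dummy_allocator.so", "dummy_alloc", "dummy_free"
)
pool = torch.cuda.MemPool(pluggable.allocator(), use_on_oom=True)
```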
Testing:
```
python test/test_cuda.py -k test_mempool_limited_memory_with_allocator
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151487
Approved by: https://github.com/eqy, https://github.com/syed-ahmed, https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 65b845f82b
Commit: d22c4cc353
```diff
@@ -4920,6 +4920,25 @@ class TestBlockStateAbsorption(TestCase):
 
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 class TestMemPool(TestCase):
+    def _setup_mempool_limited_memory_test(self, additional_allowed_memory_in_mb):
+        device = torch.device("cuda:0")
+
+        self.init_fraction = torch.cuda.get_per_process_memory_fraction()
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info(device)
+        pre_reserved = torch.cuda.memory_reserved(device)
+        total_allowed = additional_allowed_memory_in_mb * mb + pre_reserved
+        fraction_allowed = total_allowed / all_memory
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed, device)
+
+        dtype = torch.int8
+        return device, dtype
+
+    def _teardown_mempool_limited_memory_test(self):
+        torch.cuda.memory.empty_cache()
+        torch.cuda.memory.set_per_process_memory_fraction(self.init_fraction)
+
     def test_mempool_id(self):
         pool1 = torch.cuda.graph_pool_handle()
         pool2 = torch.cuda.MemPool().id
```
```diff
@@ -5036,6 +5055,110 @@ class TestMemPool(TestCase):
         # out tensor
         self.assertEqual(called_dummy_free.value, 321)
 
+    @serialTest()
+    def test_mempool_limited_memory_with_allocator(self):
+        from torch.utils.cpp_extension import load_inline
+
+        dummy_allocator_source = """
+        #include <torch/extension.h>
+        #include <ATen/cuda/Exceptions.h>
+        #include <cuda_runtime_api.h>
+
+        extern "C" {
+          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
+        }
+        """
+        dummy_allocator_libname = "dummy_allocator"
+        dummy_allocator = load_inline(
+            name=dummy_allocator_libname,
+            cpp_sources=dummy_allocator_source,
+            is_python_module=False,
+            keep_intermediates=False,
+            verbose=True,
+            with_cuda=True,
+        )
+        allocator = torch.cuda.memory.CUDAPluggableAllocator(
+            dummy_allocator,
+            "dummy_alloc",
+            "dummy_free",
+        )
+        pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
+        pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
+
+        nelem_1mb = 1024 * 1024 // 4
+
+        self._setup_mempool_limited_memory_test(80)
+        # remaining free mem: 80 mb
+        # mempool_use        []  0 mb
+        # mempool_do_not_use []  0 mb
+        # default pool       []  0 mb
+        with torch.cuda.use_mem_pool(pool_do_not_use):
+            a = torch.randn(40 * nelem_1mb, device="cuda")
+        with torch.cuda.use_mem_pool(pool_use):
+            b = torch.randn(40 * nelem_1mb, device="cuda")
+        a_dataptr = a.data_ptr()
+        b_dataptr = b.data_ptr()
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [aaaa] 40 mb
+        # mempool_use        [bbbb] 40 mb
+        # default pool       []     0 mb
+        with self.assertRaises(torch.OutOfMemoryError):
+            # out of memory
+            c = torch.randn(40 * nelem_1mb, device="cuda")
+
+        del a, b
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use        [____] 40 mb
+        # default pool       []     0 mb
+
+        # c should not oom and instead can use mempool_use as fallback
+        c = torch.randn(30 * nelem_1mb, device="cuda")
+        c_dataptr = c.data_ptr()
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use        [ccc_] 40 mb
+        # default pool       []     0 mb
+        with self.assertRaises(torch.OutOfMemoryError):
+            # out of memory since we can't use mempool_do_not_use
+            d = torch.randn(30 * nelem_1mb, device="cuda")
+
+        del c
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use        [____] 40 mb
+        # default pool       []     0 mb
+
+        # expect that we used the same memory address for both b and c
+        self.assertEqual(b_dataptr, c_dataptr)
+
+        # make sure we can still use mempool_use as intended after c is deleted
+        with torch.cuda.use_mem_pool(pool_use):
+            e = torch.randn(20 * nelem_1mb, device="cuda")
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use        [ee__] 40 mb
+        # default pool       []     0 mb
+
+        e_dataptr = e.data_ptr()
+        del e
+
+        self.assertEqual(e_dataptr, c_dataptr)
+
+        # pool's destructor calls emptyCache()
+        del pool_use, pool_do_not_use
+
+        self._teardown_mempool_limited_memory_test()
+
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()
```
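For reference, the artificial memory cap that `_setup_mempool_limited_memory_test` applies above uses only public APIs; a standalone sketch of the same setup, assuming a single visible CUDA device:

```python
import torch

device = torch.device("cuda:0")
torch.cuda.memory.empty_cache()

mb = 1024 * 1024
_, total = torch.cuda.memory.mem_get_info(device)
pre_reserved = torch.cuda.memory_reserved(device)
# allow ~80 MB of fresh allocations on top of what is already reserved
torch.cuda.memory.set_per_process_memory_fraction(
    (80 * mb + pre_reserved) / total, device
)
```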