Uses MemPoolContext to route allocations from CUDACachingAllocator (#134685)
Re-open of https://github.com/pytorch/pytorch/pull/133599, which was mistakenly closed by issuing `ghstack land`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134685
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 4b4ba7ab06
Commit: 4655eb3ee2
c10/cuda/CUDACachingAllocator.cpp
@@ -2674,7 +2674,12 @@ class DeviceCachingAllocator {
         // any potential exceptions in the cudaMallocMaybeCapturing function.
         auto sg = c10::make_scope_exit([&]() { lock.lock(); });
         lock.unlock();
-        p.err = cudaMallocMaybeCapturing(&ptr, size);
+      }
+      auto active_pool = MemPoolContext::getActiveMemPool();
+      if (active_pool && active_pool->allocator() &&
+          p.pool->owner_PrivatePool) {
+        ptr = active_pool->allocator()->raw_alloc(size);
+        p.err = ptr ? cudaSuccess : cudaErrorMemoryAllocation;
       } else {
         p.err = cudaMallocMaybeCapturing(&ptr, size);
       }
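In short, `alloc_block` no longer goes straight to `cudaMallocMaybeCapturing`: when a `MemPool` that carries its own allocator is active (via `MemPoolContext`) and the request targets that pool's private pool, the raw allocation is delegated to `active_pool->allocator()->raw_alloc(size)`. A hedged Python sketch of the behavior this enables; the `.so` path and symbol names below are placeholders, and the real wiring is built with `load_inline` in the test further down:

```python
import torch

# Placeholder library path and symbol names; see the test below for how
# a real dummy allocator is compiled with cpp_extension.load_inline.
allocator = torch.cuda.memory.CUDAPluggableAllocator(
    "libdummy_allocator.so", "dummy_alloc", "dummy_free"
)
pool = torch.cuda.MemPool(allocator.allocator())

with torch.cuda.use_mem_pool(pool):
    # While the pool is active, DeviceCachingAllocator::alloc_block routes
    # this allocation through the pool's allocator instead of cudaMalloc.
    x = torch.randn(1024, device="cuda")
```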
docs/source/cuda.rst
@@ -122,6 +122,9 @@ Memory management
     change_current_allocator
+    MemPool
+    MemPoolContext
 
+.. autoclass:: torch.cuda.use_mem_pool
 
 .. FIXME The following doesn't seem to exist. Is it supposed to?
    https://github.com/pytorch/pytorch/issues/27785
    .. autofunction:: reset_max_memory_reserved
test/test_cuda.py
@@ -2,6 +2,7 @@
 
 import collections
 import contextlib
+import ctypes
 import gc
 import json
 import os
@@ -4806,10 +4807,25 @@ class TestMemPool(TestCase):
 
         dummy_allocator_source = """
         #include <torch/extension.h>
         #include <ATen/cuda/Exceptions.h>
         #include <cuda_runtime_api.h>
 
         extern "C" {
+          C10_EXPORT int called_dummy_alloc = 0;
+          C10_EXPORT int called_dummy_free = 0;
+
           // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { return nullptr; }
-          C10_EXPORT void dummy_free(void* ptr) { }
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            called_dummy_alloc = 123;
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            called_dummy_free = 321;
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
         }
         """
         dummy_allocator_libname = "dummy_allocator"
@@ -4819,6 +4835,7 @@ class TestMemPool(TestCase):
             is_python_module=False,
             keep_intermediates=False,
             verbose=True,
+            with_cuda=True,
         )
         allocator = torch.cuda.memory.CUDAPluggableAllocator(
             dummy_allocator,
@@ -4830,6 +4847,18 @@ class TestMemPool(TestCase):
         # pool should point to the same allocator as the one passed into it
         self.assertEqual(allocator.allocator(), pool.allocator)
 
+        # no allocations happened yet, so called_dummy_alloc should be 0
+        alloc_lib = ctypes.CDLL(dummy_allocator)
+        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
+        self.assertEqual(called_dummy_alloc.value, 0)
+
+        with torch.cuda.use_mem_pool(pool):
+            out = torch.randn(1, device="cuda")
+
+        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
+        # out tensor
+        self.assertEqual(called_dummy_alloc.value, 123)
+
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()
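The assertions above lean on a small ctypes trick: the counters are plain C globals exported from the inline extension, and `ctypes.c_int.in_dll` creates a live view over them. A standalone sketch, with a hypothetical library path:

```python
import ctypes

# Open the shared library the test built; the path here is hypothetical.
lib = ctypes.CDLL("dummy_allocator.so")

# in_dll maps the exported global in place; .value re-reads the current
# contents, so writes made from the C++ side are visible immediately.
called_dummy_alloc = ctypes.c_int.in_dll(lib, "called_dummy_alloc")
print(called_dummy_alloc.value)
```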
torch/_C/__init__.pyi
@@ -1831,6 +1831,7 @@ def _cuda_cudaHostAllocator() -> _int: ...
 def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
 def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
+def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
torch/csrc/cuda/Module.cpp
@@ -1281,6 +1281,13 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         });
       });
 
+  m.def(
+      "_cuda_beginAllocateToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+            device, mempool_id, [](cudaStream_t) { return true; });
+      });
+
   m.def(
       "_cuda_endAllocateCurrentStreamToPool",
       [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
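Note the third argument to `beginAllocateToPool`: a stream filter. The always-true lambda routes allocations from every stream into the pool, whereas the pre-existing `_cuda_beginAllocateCurrentStreamToPool` binding matches only the current stream. A hedged sketch of how these private bindings pair up, mirroring the `use_mem_pool` implementation further down (internal API, subject to change):

```python
import torch
from torch._C import (
    _cuda_beginAllocateToPool,
    _cuda_endAllocateCurrentStreamToPool,
)

pool = torch.cuda.MemPool()
ctx = torch.cuda.MemPoolContext(pool)  # make the pool the active one
device = torch.cuda.current_device()

_cuda_beginAllocateToPool(device, pool.id)  # capture all streams, not just current
try:
    out = torch.empty(8, device="cuda")  # lands in `pool`
finally:
    _cuda_endAllocateCurrentStreamToPool(device, pool.id)
    del ctx
```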
torch/cuda/__init__.py
@@ -1628,6 +1628,7 @@ __all__ = [
     "memory_usage",
     "MemPool",
     "MemPoolContext",
+    "use_mem_pool",
     "temperature",
     "power_draw",
     "clock_rate",
torch/cuda/memory.py
@@ -52,6 +52,7 @@ __all__ = [
     "change_current_allocator",
     "MemPool",
     "MemPoolContext",
+    "use_mem_pool",
 ]
 
 
@@ -64,8 +65,20 @@ if not hasattr(torch._C, "_MemPool"):
     # Define dummy base classes
     torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
     torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
+    torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
+        "_cuda_beginAllocateToPool"
+    )
+    torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
+        "_cuda_endAllocateCurrentStreamToPool"
+    )
 
-from torch._C import _cuda_CUDAAllocator, _MemPool, _MemPoolContext  # noqa: F401
+from torch._C import (  # noqa: F401
+    _cuda_beginAllocateToPool,
+    _cuda_CUDAAllocator,
+    _cuda_endAllocateCurrentStreamToPool,
+    _MemPool,
+    _MemPoolContext,
+)
 
 
 def _host_allocator():
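For context on the fallback above: on a build where `torch._C` lacks these symbols (for example a CPU-only build), the placeholders keep `import torch.cuda.memory` working and only raise if something actually instantiates them. A minimal sketch of the pattern; the real helper lives in `torch._utils` and differs in detail:

```python
def _dummy_type(name: str) -> type:
    # Constructing the placeholder raises, so the import succeeds and the
    # error surfaces only at the point of actual use.
    def init_err(self, *args, **kwargs):
        raise RuntimeError(f"Tried to instantiate dummy base class {name}")

    return type(name, (object,), {"__init__": init_err})
```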
torch/cuda/memory.py
@@ -1002,3 +1015,27 @@ class MemPoolContext(_MemPoolContext):
     def active_pool() -> Optional[_MemPool]:
         r"""Returns the active MemPool"""
         return _MemPoolContext.active_pool()
+
+
+@contextlib.contextmanager
+def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
+    r"""A context manager that routes allocations to a given pool.
+
+    Args:
+        pool(torch.cuda.MemPool): a MemPool object to be made active so that
+            allocations route to this pool.
+        device (torch.device or int, optional): selected device. Uses MemPool on
+            the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    """
+    ctx = MemPoolContext(pool)
+    device_index = (
+        torch.cuda.current_device() if device is None else _get_device_index(device)
+    )
+    _cuda_beginAllocateToPool(device_index, pool.id)
+    try:
+        yield
+    finally:
+        _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
+        del ctx
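A short usage sketch of the new context manager, mirroring the test above (assumes a CUDA device; here the pool is backed by the default allocator rather than a pluggable one):

```python
import torch

pool = torch.cuda.MemPool()

with torch.cuda.use_mem_pool(pool):
    out = torch.randn(1, device="cuda")  # allocated out of `pool`

# Outside the context the pool should no longer be active, and new
# allocations return to the default pool.
assert torch.cuda.MemPoolContext.active_pool() is None
```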