Introduce a generic API torch._C._accelerator_setAllocatorSettings (#165291)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165291
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289
This commit is contained in:
Yu, Guangye
2025-10-17 17:16:44 +00:00
committed by PyTorch MergeBot
parent a1114beed2
commit b2f5c25b27
6 changed files with 23 additions and 26 deletions

View File

@@ -2048,7 +2048,6 @@ def _cuda_cudaHostAllocator() -> _int: ...
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
def _cuda_beginAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ...
def _cuda_beginAllocateCurrentThreadToPool(
device: _int,
@@ -2477,6 +2476,7 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...
# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:

View File

@@ -449,6 +449,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._accelerator_getAccelerator",
"torch._C._accelerator_getDeviceIndex",
"torch._C._accelerator_getStream",
"torch._C._accelerator_setAllocatorSettings",
"torch._C._accelerator_setStream",
"torch._C._accelerator_synchronizeDevice",
"torch._C._activate_gpu_trace",
@@ -505,7 +506,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._cuda_clearCublasWorkspaces",
"torch._C._cuda_cudaCachingAllocator_raw_alloc",
"torch._C._cuda_cudaCachingAllocator_raw_delete",
"torch._C._cuda_cudaCachingAllocator_set_allocator_settings",
"torch._C._cuda_cudaHostAllocator",
"torch._C._cuda_customAllocator",
"torch._C._cuda_emptyCache",

View File

@@ -1,3 +1,4 @@
#include <c10/core/AllocatorConfig.h>
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>
@@ -136,6 +137,10 @@ void initModule(PyObject* module) {
m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) {
at::accelerator::resetPeakStats(device_index);
});
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});
}
} // namespace torch::accelerator

View File

@@ -20,8 +20,8 @@
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <c10/core/AllocatorConfig.h>
#include <c10/core/StorageImpl.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
@@ -422,16 +422,6 @@ PyObject* THCPModule_cudaCachingAllocator_enable(
END_HANDLE_TH_ERRORS
}
// Python binding for the legacy CUDA-specific allocator-settings entry point.
// Unpacks the Python string `env` and forwards it to the CUDA caching
// allocator's settings parser; always returns Py_None on success.
// NOTE(review): this commit removes this binding in favor of the generic
// torch._C._accelerator_setAllocatorSettings (see the deprecation shim in
// torch/cuda/memory.py).
PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
PyObject* _unused,
PyObject* env) {
HANDLE_TH_ERRORS
c10::cuda::CUDACachingAllocator::setAllocatorSettings(
THPUtils_unpackString(env));
// Binding has no meaningful return value; signal success with None.
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
return THPUtils_packString(c10::cuda::CUDACachingAllocator::name());
@@ -2077,10 +2067,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_cudaCachingAllocator_enable,
METH_O,
nullptr},
{"_cuda_cudaCachingAllocator_set_allocator_settings",
THCPModule_cudaCachingAllocator_set_allocator_settings,
METH_O,
nullptr},
{"_cuda_getAllocatorBackend",
THCPModule_getAllocatorBackend,
METH_NOARGS,

View File

@@ -1101,8 +1101,12 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
f.write(_memory(snapshot))
@deprecated(
    "torch.cuda._set_allocator_settings is deprecated. Use torch._C._accelerator_setAllocatorSettings instead.",
    category=FutureWarning,
)
def _set_allocator_settings(env: str):
    """Set caching-allocator options from a configuration string.

    Deprecated shim kept for backward compatibility: it forwards ``env``
    to the generic accelerator API introduced by this change,
    ``torch._C._accelerator_setAllocatorSettings``, instead of the removed
    CUDA-only ``_cuda_cudaCachingAllocator_set_allocator_settings`` binding.

    Args:
        env: Allocator configuration string (same format previously accepted
            by the CUDA-specific binding).
    """
    # The rendered diff fused the old (removed) and new return statements;
    # only the new generic-accelerator call belongs in the post-patch code.
    return torch._C._accelerator_setAllocatorSettings(env)
def get_allocator_backend() -> str: