Introduce a generic API torch._C._accelerator_setAllocatorSettings (#165291)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165291
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289
This commit is contained in:
Yu, Guangye
2025-10-17 17:16:44 +00:00
committed by PyTorch MergeBot
parent a1114beed2
commit b2f5c25b27
6 changed files with 23 additions and 26 deletions

View File

@ -1,3 +1,4 @@
#include <c10/core/AllocatorConfig.h>
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>
@ -136,6 +137,10 @@ void initModule(PyObject* module) {
m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) {
at::accelerator::resetPeakStats(device_index);
});
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});
}
} // namespace torch::accelerator

View File

@ -20,8 +20,8 @@
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/cuda/jiterator.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <c10/core/AllocatorConfig.h>
#include <c10/core/StorageImpl.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
@ -422,16 +422,6 @@ PyObject* THCPModule_cudaCachingAllocator_enable(
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
PyObject* _unused,
PyObject* env) {
HANDLE_TH_ERRORS
c10::cuda::CUDACachingAllocator::setAllocatorSettings(
THPUtils_unpackString(env));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
return THPUtils_packString(c10::cuda::CUDACachingAllocator::name());
@ -2077,10 +2067,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_cudaCachingAllocator_enable,
METH_O,
nullptr},
{"_cuda_cudaCachingAllocator_set_allocator_settings",
THCPModule_cudaCachingAllocator_set_allocator_settings,
METH_O,
nullptr},
{"_cuda_getAllocatorBackend",
THCPModule_getAllocatorBackend,
METH_NOARGS,