refactor cached tensor to be more generic (#129359)

# Motivation
Resolves https://github.com/pytorch/pytorch/issues/129027 by refactoring the cached tensor utilities to be generic.

# Additional Context
No API names change. This change only decouples the cached tensor utilities from the CUDA build option.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129359
Approved by: https://github.com/eqy, https://github.com/EikanWang, https://github.com/albanD
Author: Yu, Guangye
Date: 2024-07-10 02:21:49 +00:00
Committed by: PyTorch MergeBot
Parent: c6aa03bd4e
Commit: f2552dcc3d
3 files changed, 32 insertions(+), 17 deletions(-)


@@ -10,6 +10,7 @@
 #include <ATen/ATen.h>
 #include <ATen/BlasBackend.h>
+#include <ATen/CachedTensorUtils.h>
 #include <ATen/DLConvertor.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/LegacyVmapMode.h>
@@ -1785,6 +1786,22 @@ Call this whenever a new thread is created in order to propagate values from
 :func:`torch.set_num_threads` onto the new thread.
 )");
+  py_module.def("_set_cached_tensors_enabled", [](bool enabled) {
+    at::caching::set_cached_tensors_enabled(enabled);
+  });
+  py_module.def("_add_cached_tensor", [](const at::Tensor& t) {
+    at::caching::add_cached_tensor(t);
+  });
+  py_module.def("_remove_cached_tensor", [](const at::Tensor& t) {
+    at::caching::remove_cached_tensor(t);
+  });
+  py_module.def("_is_cached_tensor", [](const at::Tensor& t) {
+    return at::caching::is_cached_tensor(t);
+  });
   ASSERT_TRUE(
       set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
   ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
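
For reference, a minimal Python-side sketch of how the bindings added above can be exercised. The function names are taken directly from the diff; they live on the private torch._C module, and the runtime behavior sketched here (return values, tracking disabled by default) is an assumption rather than documented API:

    import torch

    # Assumption: cached-tensor tracking is off by default and must be enabled.
    torch._C._set_cached_tensors_enabled(True)

    t = torch.ones(2, 2)
    torch._C._add_cached_tensor(t)        # register t as a cached tensor
    print(torch._C._is_cached_tensor(t))  # expected: True while registered
    torch._C._remove_cached_tensor(t)     # unregister t
    print(torch._C._is_cached_tensor(t))  # expected: False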