mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
refactor cached tensor more generic (#129359)
# Motivation solve https://github.com/pytorch/pytorch/issues/129027 to refactor cached tensor to be generic. # Additional Context No API name change. It is only decoupling with CUDA build option. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129359 Approved by: https://github.com/eqy, https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6aa03bd4e
commit
f2552dcc3d
@ -10,6 +10,7 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/BlasBackend.h>
|
||||
#include <ATen/CachedTensorUtils.h>
|
||||
#include <ATen/DLConvertor.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/LegacyVmapMode.h>
|
||||
@ -1785,6 +1786,22 @@ Call this whenever a new thread is created in order to propagate values from
|
||||
:func:`torch.set_num_threads` onto the new thread.
|
||||
)");
|
||||
|
||||
py_module.def("_set_cached_tensors_enabled", [](bool enabled) {
|
||||
at::caching::set_cached_tensors_enabled(enabled);
|
||||
});
|
||||
|
||||
py_module.def("_add_cached_tensor", [](const at::Tensor& t) {
|
||||
at::caching::add_cached_tensor(t);
|
||||
});
|
||||
|
||||
py_module.def("_remove_cached_tensor", [](const at::Tensor& t) {
|
||||
at::caching::remove_cached_tensor(t);
|
||||
});
|
||||
|
||||
py_module.def("_is_cached_tensor", [](const at::Tensor& t) {
|
||||
return at::caching::is_cached_tensor(t);
|
||||
});
|
||||
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
|
||||
|
Reference in New Issue
Block a user