[Land Internally] MTIA equivalent of torch.cuda.memory_stats (#132007)

Summary: as title

Test Plan: pytorch ci failing: https://github.com/pytorch/pytorch/issues/131962

Differential Revision: D60335413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132007
Approved by: https://github.com/hanzlfs, https://github.com/egienvalue
Author: Simon Mahns
Date: 2024-07-29 20:47:18 +00:00
Committed by: PyTorch MergeBot
Parent: 082d0b80ca
Commit: dcb03106b7
5 changed files with 33 additions and 6 deletions
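
A minimal usage sketch of the new public API (assuming the out-of-tree MTIA extension is installed and at least one MTIA device is visible); it mirrors how torch.cuda.memory_stats is typically queried:

import torch

# Hedged example: requires the out-of-tree MTIA extension and an MTIA device.
if torch.mtia.is_available():
    torch.mtia.init()
    stats = torch.mtia.memory_stats(torch.mtia.current_device())
    # The result is a plain dict of allocator counters; the exact key set is
    # defined by the MTIA allocator and is not spelled out in this commit.
    for name, value in sorted(stats.items()):
        print(f"{name}: {value}")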

View File

@@ -8,6 +8,7 @@
#include <c10/core/Allocator.h>
#include <c10/util/python_stub.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <string>
@@ -17,7 +18,6 @@ class Context;
}
namespace at {
constexpr const char* MTIA_HELP =
    "The MTIA backend requires MTIA extension for PyTorch;"
    "this error has occurred because you are trying "
@@ -99,6 +99,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  virtual PyObject* memoryStats(DeviceIndex device) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }
};

struct TORCH_API MTIAHooksArgs {};
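
For context, a hedged sketch of the default hook's behavior above: when the out-of-tree MTIA extension has not registered real hooks, memoryStats() falls through to FAIL_MTIAHOOKS_FUNC instead of returning a dict, which is expected to surface in Python as a RuntimeError (the exact message text is not shown in this hunk):

import torch

# Hypothetical probe on a build without the MTIA extension: the private
# binding _mtia_memoryStats (registered later in this commit) reaches the
# default hook, which rejects the call rather than returning statistics.
if not torch.mtia.is_available():
    try:
        torch._C._mtia_memoryStats(0)
    except RuntimeError as err:
        print("default MTIA hook rejected the call:", err)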

View File

@@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are defined
    init
    is_available
    is_initialized
    memory_stats
    set_device
    set_stream
    stream

View File

@@ -1772,6 +1772,7 @@ def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
# Defined in torch/csrc/mps/Module.cpp

View File

@@ -1,13 +1,12 @@
#include <ATen/ATen.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#ifndef WIN32
#include <pthread.h>
#endif
@@ -75,6 +74,12 @@ void initModule(PyObject* module) {
    }
    at::detail::getMTIAHooks().setCurrentStream(stream);
  });

  m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().memoryStats(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
}

} // namespace mtia
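
A short note on the binding above: the hook returns a raw PyObject* (the statistics dict), and py::reinterpret_steal wraps it as a py::object without bumping the refcount, so pybind11 takes ownership of the reference the hook produced. From Python the call simply yields an ordinary dict; a hedged check, assuming a working MTIA setup:

import torch

# Goes through torch._C._mtia_memoryStats under the hood.
if torch.mtia.is_available():
    torch.mtia.init()
    stats = torch.mtia.memory_stats()
    assert isinstance(stats, dict)
    assert all(isinstance(key, str) for key in stats)  # keys are stat names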

View File

@@ -48,8 +48,8 @@ def _lazy_init() -> None:
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # We be double-checked locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # We be double-checking locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # is for when a thread blocked on some other thread which was
        # doing the initialization; when they get the lock, they will
        # find there is nothing left to do.
@@ -148,6 +148,19 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))


def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device or int, optional) selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    if not is_initialized():
        return {}
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def set_stream(stream: Stream):
    r"""Set the current stream.This is a wrapper API to set the stream.
    Usage of this function is discouraged in favor of the ``stream``
@@ -209,6 +222,7 @@ class StreamContext:
    cur_stream: Optional["torch.mtia.Stream"]

    def __init__(self, stream: Optional["torch.mtia.Stream"]):
        self.cur_stream = None
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
@@ -304,6 +318,7 @@ __all__ = [
    "current_device",
    "current_stream",
    "default_stream",
    "memory_stats",
    "set_device",
    "set_stream",
    "stream",