mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Land Internally] MTIA equivalent of torch.cuda.memory_stats (#132007)
Summary: as title Test Plan: pytorch ci failing: https://github.com/pytorch/pytorch/issues/131962 Differential Revision: D60335413 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132007 Approved by: https://github.com/hanzlfs, https://github.com/egienvalue
This commit is contained in:
committed by
PyTorch MergeBot
parent
082d0b80ca
commit
dcb03106b7
@ -8,6 +8,7 @@
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
#include <string>
|
||||
@ -17,7 +18,6 @@ class Context;
|
||||
}
|
||||
|
||||
namespace at {
|
||||
|
||||
constexpr const char* MTIA_HELP =
|
||||
"The MTIA backend requires MTIA extension for PyTorch;"
|
||||
"this error has occurred because you are trying "
|
||||
@ -99,6 +99,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual PyObject* memoryStats(DeviceIndex device) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
struct TORCH_API MTIAHooksArgs {};
|
||||
|
@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
|
||||
init
|
||||
is_available
|
||||
is_initialized
|
||||
memory_stats
|
||||
set_device
|
||||
set_stream
|
||||
stream
|
||||
|
@ -1772,6 +1772,7 @@ def _mtia_deviceSynchronize() -> None: ...
|
||||
def _mtia_getCurrentStream(device: _int) -> Stream: ...
|
||||
def _mtia_setCurrentStream(stream: Stream) -> None: ...
|
||||
def _mtia_getDefaultStream(device: _int) -> Stream: ...
|
||||
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
# Defined in torch/csrc/mps/Module.cpp
|
||||
|
@ -1,13 +1,12 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <c10/util/CallOnce.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/Stream.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/device_lazy_init.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#ifndef WIN32
|
||||
#include <pthread.h>
|
||||
#endif
|
||||
@ -75,6 +74,12 @@ void initModule(PyObject* module) {
|
||||
}
|
||||
at::detail::getMTIAHooks().setCurrentStream(stream);
|
||||
});
|
||||
|
||||
m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
|
||||
PyObject* raw_pyobject =
|
||||
at::detail::getMTIAHooks().memoryStats(device_index);
|
||||
return py::reinterpret_steal<py::object>(raw_pyobject);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mtia
|
||||
|
@ -48,8 +48,8 @@ def _lazy_init() -> None:
|
||||
if is_initialized() or hasattr(_tls, "is_initializing"):
|
||||
return
|
||||
with _initialization_lock:
|
||||
# We be double-checked locking, boys! This is OK because
|
||||
# the above test was GIL protected anyway. The inner test
|
||||
# We be double-checking locking, boys! This is OK because
|
||||
# the above test was GIL protected anyway. The inner test
|
||||
# is for when a thread blocked on some other thread which was
|
||||
# doing the initialization; when they get the lock, they will
|
||||
# find there is nothing left to do.
|
||||
@ -148,6 +148,19 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
|
||||
return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))
|
||||
|
||||
|
||||
def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
|
||||
r"""Return a dictionary of MTIA memory allocator statistics for a given device.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional) selected device. Returns
|
||||
statistics for the current device, given by current_device(),
|
||||
if device is None (default).
|
||||
"""
|
||||
if not is_initialized():
|
||||
return {}
|
||||
return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))
|
||||
|
||||
|
||||
def set_stream(stream: Stream):
|
||||
r"""Set the current stream.This is a wrapper API to set the stream.
|
||||
Usage of this function is discouraged in favor of the ``stream``
|
||||
@ -209,6 +222,7 @@ class StreamContext:
|
||||
cur_stream: Optional["torch.mtia.Stream"]
|
||||
|
||||
def __init__(self, stream: Optional["torch.mtia.Stream"]):
|
||||
self.cur_stream = None
|
||||
self.stream = stream
|
||||
self.idx = _get_device_index(None, True)
|
||||
if not torch.jit.is_scripting():
|
||||
@ -304,6 +318,7 @@ __all__ = [
|
||||
"current_device",
|
||||
"current_stream",
|
||||
"default_stream",
|
||||
"memory_stats",
|
||||
"set_device",
|
||||
"set_stream",
|
||||
"stream",
|
||||
|
Reference in New Issue
Block a user