mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Kineto] Enable OOM observer (#152160)
Summary: # Context: When a memory leak happens, it usually triggers OOM in later iterations. The snapshot of a full iteration will be huge and hard to interpret. On the CUDA side, they provide an OOM observer which generates a snapshot when OOM happens, containing the latest 1,500,000 entries for debugging. In this diff, we want to implement the same feature on the MTIA side. Test Plan: Run this test with the last diff in the stack. ``` buck run @//mode/opt kineto/libkineto/fb/mtia/integration_tests:mtia_memory_auto_trace_test ``` As shown, the memory snapshot is generated when OOM happens. Log: P1794792326 Snapshot: https://fburl.com/pytorch_memory_visualizer/lx73y6s3 {F1977402355} Differential Revision: D71993315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152160 Approved by: https://github.com/sraikund16
This commit is contained in:
committed by
PyTorch MergeBot
parent
c4b0854750
commit
861945100e
@@ -141,6 +141,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void attachOutOfMemoryObserver(PyObject* observer) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return;
  }
};

struct TORCH_API MTIAHooksArgs {};
@@ -23,6 +23,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
    empty_cache
    record_memory_history
    snapshot
    attach_out_of_memory_observer
    set_device
    set_stream
    stream
@@ -1868,6 +1868,9 @@ def _mtia_recordMemoryHistory(
    max_entries
) -> None: ...
def _mtia_memorySnapshot() -> Dict[str, Any]: ...
def _mtia_attachOutOfMemoryObserver(
    observer: Callable[[_int, _int, _int, _int], None]
) -> None: ...
def _mtia_getDeviceCount() -> _int: ...
def _mtia_resetPeakMemoryStats(device: _int) -> None: ...
@@ -100,6 +100,11 @@ void initModule(PyObject* module) {
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });

  m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
    at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
    return;
  });

  m.def("_mtia_getDeviceCount", []() {
    return at::detail::getMTIAHooks().deviceCount();
  });
@@ -197,6 +197,13 @@ def snapshot() -> dict[str, Any]:
    return torch._C._mtia_memorySnapshot()


def attach_out_of_memory_observer(
    observer: Callable[[int, int, int, int], None]
) -> None:
    r"""Attach an out-of-memory observer to MTIA memory allocator"""
    torch._C._mtia_attachOutOfMemoryObserver(observer)


def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
    r"""Return capability of a given device as a tuple of (major version, minor version).
@@ -378,6 +385,7 @@ __all__ = [
    "get_device_capability",
    "record_memory_history",
    "snapshot",
    "attach_out_of_memory_observer",
    "empty_cache",
    "set_device",
    "set_stream",
Reference in New Issue
Block a user