[Kineto] Enable OOM observer (#152160)

Summary:
# Context:
When a memory leak occurs, it usually triggers an OOM in a later iteration. A snapshot covering the full iteration would be huge and hard to interpret.
On the CUDA side, an OOM observer is provided that generates a snapshot when OOM happens, containing the latest 1,500,000 entries for debugging.

In this diff, we implement the same feature on the MTIA side.
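The idea can be sketched in plain Python (the class, names, and buffer size below are illustrative, not the MTIA implementation): the allocator keeps only the most recent trace entries in a bounded ring buffer, and hands them to a user-installed observer when an allocation fails, instead of snapshotting every iteration up front.

```python
from collections import deque

class TracingAllocator:
    """Toy allocator sketch: records recent events, dumps them on OOM."""

    def __init__(self, capacity_bytes, max_trace_entries=1_500_000):
        self.capacity = capacity_bytes
        self.used = 0
        # Bounded buffer: only the latest entries survive, which keeps the
        # OOM-time snapshot small enough to interpret.
        self.trace = deque(maxlen=max_trace_entries)
        self.oom_observer = None  # callback installed by the user

    def attach_out_of_memory_observer(self, observer):
        self.oom_observer = observer

    def alloc(self, nbytes):
        self.trace.append(("alloc", nbytes))
        if self.used + nbytes > self.capacity:
            if self.oom_observer is not None:
                # Hand the recent history to the observer at OOM time.
                self.oom_observer(list(self.trace))
            raise MemoryError(f"OOM: requested {nbytes}, used {self.used}")
        self.used += nbytes

snapshots = []
a = TracingAllocator(capacity_bytes=100, max_trace_entries=4)
a.attach_out_of_memory_observer(snapshots.append)
for _ in range(5):
    try:
        a.alloc(30)
    except MemoryError:
        break
# Only the 4 most recent events are retained in the snapshot.
```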

Test Plan:
Run this test with the last diff in the stack:
```
buck run @//mode/opt kineto/libkineto/fb/mtia/integration_tests:mtia_memory_auto_trace_test
```

As shown, the memory snapshot is generated when OOM happens.
Log: P1794792326
Snapshot: https://fburl.com/pytorch_memory_visualizer/lx73y6s3 {F1977402355}
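With the new binding in place, usage would presumably look like the sketch below. The observer receives four ints, matching `Callable[[int, int, int, int], None]` in the type stub this diff adds; interpreting them as `(device, alloc, device_alloc, device_free)` mirrors the existing CUDA OOM-observer callback and is an assumption here, as is the simulated callback invocation.

```python
events = []

def oom_observer(device, alloc, device_alloc, device_free):
    # Called by the allocator at OOM time; a real handler might call
    # torch.mtia.snapshot() here and persist it for the memory visualizer.
    events.append({"device": device, "alloc": alloc,
                   "device_alloc": device_alloc, "device_free": device_free})

# On an MTIA build, the observer would be installed with:
#   torch.mtia.attach_out_of_memory_observer(oom_observer)
# Here we simulate the allocator firing the callback with four ints.
oom_observer(0, 1024, 8_000_000, 512)
```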

Differential Revision: D71993315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152160
Approved by: https://github.com/sraikund16
Author: Zizeng Meng
Date: 2025-04-27 15:56:41 +00:00
Committed by: PyTorch MergeBot
Parent: c4b0854750
Commit: 861945100e
5 changed files with 21 additions and 0 deletions


@@ -141,6 +141,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void attachOutOfMemoryObserver(PyObject* observer) const {
FAIL_MTIAHOOKS_FUNC(__func__);
return;
}
};
struct TORCH_API MTIAHooksArgs {};


@@ -23,6 +23,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
empty_cache
record_memory_history
snapshot
attach_out_of_memory_observer
set_device
set_stream
stream


@@ -1868,6 +1868,9 @@ def _mtia_recordMemoryHistory(
max_entries
) -> None: ...
def _mtia_memorySnapshot() -> Dict[str, Any]: ...
def _mtia_attachOutOfMemoryObserver(
observer: Callable[[_int, _int, _int, _int], None]
) -> None: ...
def _mtia_getDeviceCount() -> _int: ...
def _mtia_resetPeakMemoryStats(device: _int) -> None: ...


@@ -100,6 +100,11 @@ void initModule(PyObject* module) {
return py::reinterpret_steal<py::object>(raw_pyobject);
});
m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
return;
});
m.def("_mtia_getDeviceCount", []() {
return at::detail::getMTIAHooks().deviceCount();
});


@@ -197,6 +197,13 @@ def snapshot() -> dict[str, Any]:
return torch._C._mtia_memorySnapshot()
def attach_out_of_memory_observer(
observer: Callable[[int, int, int, int], None]
) -> None:
r"""Attach an out-of-memory observer to MTIA memory allocator"""
torch._C._mtia_attachOutOfMemoryObserver(observer)
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Return capability of a given device as a tuple of (major version, minor version).
@@ -378,6 +385,7 @@ __all__ = [
"get_device_capability",
"record_memory_history",
"snapshot",
"attach_out_of_memory_observer",
"empty_cache",
"set_device",
"set_stream",