pytorch/torch/csrc/mtia/Module.cpp
Zizeng Meng 861945100e [Kineto] Enable OOM observer (#152160)
Summary:
# Context:
When a memory leak happens, it usually triggers an OOM in a later iteration, and a snapshot of the full iteration would be huge and hard to interpret.
On the CUDA side, an OOM observer is provided that generates a snapshot containing the latest 1,500,000 entries for debugging when an OOM occurs.

In this diff, we implement the same feature on the MTIA side.
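
A rough sketch of attaching an observer from Python via the raw binding added in this file (the `torch._C._mtia_*` names come from the code below; the callback's signature and what it should do with the snapshot are assumptions):
```
import torch

# Hypothetical observer: the exact arguments MTIA passes on OOM are an
# assumption here, so accept anything and dump a snapshot for later triage.
def on_oom(*args, **kwargs):
    snapshot = torch._C._mtia_memorySnapshot()
    # persist `snapshot` somewhere the memory visualizer can load it

torch._C._mtia_attachOutOfMemoryObserver(on_oom)
```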

Test Plan:
Run this test together with the last diff in the stack:
```
buck run @//mode/opt  kineto/libkineto/fb/mtia/integration_tests:mtia_memory_auto_trace_test
```

As shown, the memory snapshot is generated when the OOM happens.
Log: P1794792326
Snapshot: https://fburl.com/pytorch_memory_visualizer/lx73y6s3
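
For reference, a sketch of enabling allocator history so the snapshot has entries to show; the values for `enabled` and `stacks` are assumed by analogy with CUDA's `_record_memory_history` and are not confirmed by this diff:
```
import torch

# "all"/"python" are assumed-by-analogy argument values; max_entries mirrors
# the 1,500,000-entry window mentioned in the context above.
torch._C._mtia_recordMemoryHistory("all", "python", 1500000)
```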

Differential Revision: D71993315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152160
Approved by: https://github.com/sraikund16
2025-04-27 15:56:44 +00:00

#include <ATen/ATen.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/mtia/Module.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::mtia {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("_mtia_init", []() {
    TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA));
    torch::utils::register_fork_handler_for_device_init(at::kMTIA);
    at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
  });
  m.def("_mtia_isBuilt", []() {
    // Check if the MTIAHooks class has been registered with the registry.
    return at::detail::isMTIAHooksBuilt();
  });
  m.def("_mtia_isInBadFork", []() {
    return torch::utils::is_device_in_bad_fork(at::kMTIA);
  });
  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getCurrentStream(device_index);
  });
  m.def("_mtia_getCurrentRawStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getCurrentRawStream(device_index);
  });
  m.def("_mtia_deviceSynchronize", []() {
    torch::utils::device_lazy_init(at::kMTIA);
    at::detail::getMTIAHooks().deviceSynchronize(
        at::detail::getMTIAHooks().getCurrentDevice());
  });
  m.def("_mtia_exchangeDevice", [](c10::DeviceIndex device_index) {
    // A negative index is a no-op; return -1 to signal no device switch.
    if (device_index < 0) {
      return static_cast<c10::DeviceIndex>(-1);
    }
    return at::detail::getMTIAHooks().exchangeDevice(device_index);
  });
  m.def("_mtia_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
    if (device_index < 0) {
      return static_cast<c10::DeviceIndex>(-1);
    }
    return at::detail::getMTIAHooks().maybeExchangeDevice(device_index);
  });
  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getDefaultStream(device_index);
  });
  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
    torch::utils::device_lazy_init(at::kMTIA);
    // Switch to the stream's device first if it is not the current one.
    auto device = at::detail::getMTIAHooks().getCurrentDevice();
    if (device != stream.device_index()) {
      at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
    }
    at::detail::getMTIAHooks().setCurrentStream(stream);
  });
  m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().memoryStats(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().getDeviceCapability(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_emptyCache", []() { at::detail::getMTIAHooks().emptyCache(); });
  m.def(
      "_mtia_recordMemoryHistory",
      [](const std::optional<std::string>& enabled,
         const std::string& stacks,
         size_t max_entries) {
        at::detail::getMTIAHooks().recordMemoryHistory(
            enabled, stacks, max_entries);
      });
  m.def("_mtia_memorySnapshot", []() {
    PyObject* raw_pyobject = at::detail::getMTIAHooks().memorySnapshot();
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
    // Hand the Python callable to the device hooks; it is invoked on OOM.
    at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
  });
  m.def("_mtia_getDeviceCount", []() {
    return at::detail::getMTIAHooks().deviceCount();
  });
  m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
    at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
  });
}
} // namespace torch::mtia