[MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode (#165963)

Summary:
This diff adds the binding and hook layer for MTIAGraph, consisting of
1. the binding between Python and C++
2. the hook between PyTorch and the MTIA fbcode implementation

A call-flow sketch follows the doc link below.
(Architecture diagram omitted; see the PR page.)

[Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00)
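
For orientation, here is a minimal sketch of the call path this layer enables, written against the hooks API added in this diff. The capture body and the `{0, 0}` "let the backend pick a pool" id are illustrative assumptions, not part of this change:

```cpp
// Sketch only: drive an MTIA graph through the new accelerator hooks.
// Assumes an MTIA backend has registered a concrete MTIAHooksInterface.
#include <ATen/detail/MTIAHooksInterface.h>

void capture_and_replay() {
  const auto& hooks = at::detail::getMTIAHooks(); // resolves the registered backend
  int64_t handle = hooks.mtiagraphCreate(/*keep_graph=*/false);
  hooks.mtiagraphCaptureBegin(handle, /*pool=*/{0, 0}); // {0, 0}: illustrative default id
  // ... enqueue the MTIA work that should be recorded into the graph ...
  hooks.mtiagraphCaptureEnd(handle);
  hooks.mtiagraphInstantiate(handle);
  hooks.mtiagraphReplay(handle); // re-submit the captured work
}
```

The Python-facing `_MTIAGraph` class added below wraps exactly this sequence, holding nothing but the integer handle returned by `mtiagraphCreate`.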

Test Plan: Will be exercised by the upcoming Python implementation, which uses this binding and hook.

Differential Revision: D84457757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963
Approved by: https://github.com/malfet, https://github.com/albanD
Author: Andy (An) Wang
Date: 2025-10-31 02:52:48 +00:00
Committed by: PyTorch MergeBot
Parent: 1129605415
Commit: d3be06cbdc
2 changed files with 76 additions and 0 deletions

aten/src/ATen/detail/MTIAHooksInterface.h

@@ -1,5 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
@@ -151,6 +152,36 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
  }
  virtual bool isAvailable() const override;

  /* MTIAGraph related APIs */
  virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void mtiagraphCaptureEnd(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void mtiagraphInstantiate(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void mtiagraphReplay(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual void mtiagraphReset(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual MempoolId_t mtiagraphPool(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
};
struct TORCH_API MTIAHooksArgs {};
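
On the fbcode side (not part of this diff), a backend would override these virtuals and back the integer handles with a registry of live graph objects, since the hooks boundary passes only functions and scalar handles, not classes. A minimal sketch, where `FbcodeMTIAHooks`, `MtiaGraph`, and the registry are hypothetical stand-ins for the real fbcode types:

```cpp
// Hypothetical backend sketch; names and storage strategy are illustrative.
#include <ATen/detail/MTIAHooksInterface.h>
#include <atomic>
#include <memory>
#include <unordered_map>

namespace {

// Stand-in for the real MTIA runtime graph object.
struct MtiaGraph {
  explicit MtiaGraph(bool keep_graph) : keep_graph_(keep_graph) {}
  void replay() { /* submit the captured work to the device */ }
  bool keep_graph_;
};

// The hook methods are const, so per-handle state lives in a static
// registry rather than on the hooks object (unsynchronized in this sketch).
std::unordered_map<int64_t, std::unique_ptr<MtiaGraph>> graph_registry;
std::atomic<int64_t> next_handle{0};

} // namespace

struct FbcodeMTIAHooks : at::MTIAHooksInterface {
  int64_t mtiagraphCreate(bool keep_graph) const override {
    int64_t h = next_handle++;
    graph_registry[h] = std::make_unique<MtiaGraph>(keep_graph);
    return h; // this handle is what crosses the PyTorch/fbcode boundary
  }
  void mtiagraphReplay(int64_t handle) const override {
    graph_registry.at(handle)->replay();
  }
  // mtiagraphCaptureBegin/CaptureEnd/Instantiate/Reset/Pool follow the
  // same look-up-by-handle pattern.
};
```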

torch/csrc/mtia/Module.cpp

@@ -10,6 +10,42 @@
namespace torch::mtia {

struct _MTIAGraph {
  // MTIA uses accelerator hooks to connect PyTorch and the outside runtime.
  // We need to provide the MTIAGraph class at the Python layer, but the hooks
  // only support hooking functions, not classes. Thus we store all MTIAGraph
  // C++ instances in a map, and use a handle to choose the right instance.
  int64_t handle_;

  _MTIAGraph(bool keep_graph = false)
      : handle_(at::detail::getMTIAHooks().mtiagraphCreate(keep_graph)) {}
  ~_MTIAGraph() = default;

  void capture_begin(at::MempoolId_t pool) {
    at::detail::getMTIAHooks().mtiagraphCaptureBegin(handle_, pool);
  }

  void capture_end() {
    at::detail::getMTIAHooks().mtiagraphCaptureEnd(handle_);
  }

  void instantiate() {
    at::detail::getMTIAHooks().mtiagraphInstantiate(handle_);
  }

  void replay() {
    at::detail::getMTIAHooks().mtiagraphReplay(handle_);
  }

  void reset() {
    at::detail::getMTIAHooks().mtiagraphReset(handle_);
  }

  at::MempoolId_t pool() {
    return at::detail::getMTIAHooks().mtiagraphPool(handle_);
  }
};

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
@@ -131,6 +167,15 @@ void initModule(PyObject* module) {
  m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
    at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
  });

  py::class_<_MTIAGraph>(m, "_MTIAGraph")
      .def(py::init<bool>(), py::arg("keep_graph") = false)
      .def("capture_begin", &_MTIAGraph::capture_begin)
      .def("capture_end", &_MTIAGraph::capture_end)
      .def("instantiate", &_MTIAGraph::instantiate)
      .def("replay", &_MTIAGraph::replay)
      .def("reset", &_MTIAGraph::reset)
      .def("pool", &_MTIAGraph::pool);
}
} // namespace torch::mtia
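
One more note on the `pool` plumbing: `mtiagraphCaptureBegin` takes a `MempoolId_t` and `mtiagraphPool` reports the pool a captured graph ended up using, mirroring the memory-pool sharing scheme of CUDA graphs. Assuming the MTIA backend treats `{0, 0}` as "backend chooses" the way CUDA graphs do (an assumption; this diff only plumbs the id through), two graphs could share one pool like this:

```cpp
// Sketch: share one memory pool across two graphs (assumes CUDA-graphs-like
// pool semantics on the MTIA side, with {0, 0} meaning "backend chooses").
#include <ATen/detail/MTIAHooksInterface.h>

void share_pool_between_graphs() {
  const auto& hooks = at::detail::getMTIAHooks();

  int64_t g1 = hooks.mtiagraphCreate();
  hooks.mtiagraphCaptureBegin(g1, /*pool=*/{0, 0});
  // ... capture work for g1 ...
  hooks.mtiagraphCaptureEnd(g1);

  at::MempoolId_t shared = hooks.mtiagraphPool(g1); // pool assigned to g1

  int64_t g2 = hooks.mtiagraphCreate();
  hooks.mtiagraphCaptureBegin(g2, shared); // g2 allocates from g1's pool
  // ... capture work for g2 ...
  hooks.mtiagraphCaptureEnd(g2);
}
```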