mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode (#165963)
Summary: This diff adds the binding and hook layer for MTIAGraph, including (1) the binding between Python and C++, and (2) the hook between PyTorch and the MTIA fbcode implementation. <img width="1780" height="754" alt="image" src="https://github.com/user-attachments/assets/31e24e5b-8324-42d8-8d3b-59536bc18340" /> [Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00) Test Plan: Will be tested by the Python implementation that uses this binding and hook. Differential Revision: D84457757 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963 Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							1129605415
						
					
				
				
					commit
					d3be06cbdc
				
			@ -1,5 +1,6 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <c10/core/CachingDeviceAllocator.h>
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
 | 
			
		||||
@ -151,6 +152,36 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool isAvailable() const override;
 | 
			
		||||
 | 
			
		||||
  /* MTIAGraph related APIs */
 | 
			
		||||
  // Creates a backend-side MTIAGraph instance and returns an opaque integer
  // handle identifying it. This default implementation fails via
  // FAIL_MTIAHOOKS_FUNC; a real MTIA backend overrides it. The -1 return is a
  // fallback in case the failure macro does not terminate the path.
  virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }
 | 
			
		||||
 | 
			
		||||
  // Begins capture on the graph identified by `handle`, using the given
  // memory pool id. Default implementation fails; overridden by the MTIA
  // backend.
  virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
 | 
			
		||||
 | 
			
		||||
  // Ends capture on the graph identified by `handle`. Default implementation
  // fails; overridden by the MTIA backend.
  virtual void mtiagraphCaptureEnd(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
 | 
			
		||||
 | 
			
		||||
  // Instantiates the captured graph identified by `handle` so it can be
  // replayed. Default implementation fails; overridden by the MTIA backend.
  virtual void mtiagraphInstantiate(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
 | 
			
		||||
 | 
			
		||||
  // Replays the graph identified by `handle`. Default implementation fails;
  // overridden by the MTIA backend.
  virtual void mtiagraphReplay(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
 | 
			
		||||
 | 
			
		||||
  // Resets the graph identified by `handle`. Default implementation fails;
  // overridden by the MTIA backend.
  virtual void mtiagraphReset(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
 | 
			
		||||
 | 
			
		||||
  // Returns the memory pool id associated with the graph identified by
  // `handle`. Default implementation fails via FAIL_MTIAHOOKS_FUNC; a real
  // MTIA backend overrides it. The fallback return keeps this non-void
  // function well-formed (no -Wreturn-type / UB) if the failure macro does
  // not terminate the path, mirroring mtiagraphCreate's `return -1;`.
  virtual MempoolId_t mtiagraphPool(int64_t handle) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return {0, 0};
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Argument bag passed when registering/constructing MTIA hooks; currently
// empty.
struct TORCH_API MTIAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,42 @@
 | 
			
		||||
 | 
			
		||||
namespace torch::mtia {
 | 
			
		||||
 | 
			
		||||
struct _MTIAGraph {
 | 
			
		||||
  // MTIA use accelerator hooks to connect pytorch and outside.
 | 
			
		||||
  // We need to provide the MTIAGraph class at Python layer, but the hooks only
 | 
			
		||||
  // support hooking functions, not classes. Thus we store all MTIAGraph C++
 | 
			
		||||
  // instances in a map, and use a handle to choose the right instance.
 | 
			
		||||
  int64_t handle_;
 | 
			
		||||
 | 
			
		||||
  _MTIAGraph(bool keep_graph = false)
 | 
			
		||||
      : handle_(at::detail::getMTIAHooks().mtiagraphCreate(keep_graph)) {}
 | 
			
		||||
  ~_MTIAGraph() = default;
 | 
			
		||||
 | 
			
		||||
  void capture_begin(at::MempoolId_t pool) {
 | 
			
		||||
    at::detail::getMTIAHooks().mtiagraphCaptureBegin(handle_, pool);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void capture_end() {
 | 
			
		||||
    at::detail::getMTIAHooks().mtiagraphCaptureEnd(handle_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void instantiate() {
 | 
			
		||||
    at::detail::getMTIAHooks().mtiagraphInstantiate(handle_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void replay() {
 | 
			
		||||
    at::detail::getMTIAHooks().mtiagraphReplay(handle_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void reset() {
 | 
			
		||||
    at::detail::getMTIAHooks().mtiagraphReset(handle_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  at::MempoolId_t pool() {
 | 
			
		||||
    return at::detail::getMTIAHooks().mtiagraphPool(handle_);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void initModule(PyObject* module) {
 | 
			
		||||
  auto m = py::handle(module).cast<py::module>();
 | 
			
		||||
 | 
			
		||||
@ -131,6 +167,15 @@ void initModule(PyObject* module) {
 | 
			
		||||
  // Resets the peak memory stats counters for the given MTIA device.
  m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
    at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
  });
 | 
			
		||||
 | 
			
		||||
  // Expose the C++ _MTIAGraph wrapper to Python; each bound method forwards
  // to the corresponding accelerator-hook call via the stored handle.
  py::class_<_MTIAGraph>(m, "_MTIAGraph")
      .def(py::init<bool>(), py::arg("keep_graph") = false)
      .def("capture_begin", &_MTIAGraph::capture_begin)
      .def("capture_end", &_MTIAGraph::capture_end)
      .def("instantiate", &_MTIAGraph::instantiate)
      .def("replay", &_MTIAGraph::replay)
      .def("reset", &_MTIAGraph::reset)
      .def("pool", &_MTIAGraph::pool);
}
 | 
			
		||||
 | 
			
		||||
} // namespace torch::mtia
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user