Files
pytorch/torch/csrc/mtia/Module.cpp
anwang c55e72bea1 [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)
The previous [diff/PR](https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:
<img width="1736" height="722" alt="image" src="https://github.com/user-attachments/assets/216b1720-4002-48da-b5f3-32b5d48aaa54" />
I didn't add the docstring because I thought I wasn't supposed to add one for an EXISTING function.

So this diff/PR is an exact copy of the previous one, except for adding the docstring.

-------------
This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code (Triton kernels + Python wrapper code) similar to CUDA, and the Triton kernels can be launched eagerly.
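For context, a minimal sketch of what this enables; the model and shapes here are illustrative assumptions, not part of this diff:

```python
import torch

def fn(x, y):
    return torch.nn.functional.relu(x @ y)

# Illustrative only: assumes an MTIA build of PyTorch and a visible device.
x = torch.randn(64, 64, device="mtia")
y = torch.randn(64, 64, device="mtia")

# With this diff, Inductor emits Triton kernels plus a Python wrapper for
# MTIA (mirroring the CUDA path), and the kernels are launched eagerly.
compiled = torch.compile(fn, backend="inductor")
out = compiled(x, y)
```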

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, covering device, stream, and event APIs.
- Add required `torch.mtia` APIs, like `is_bf16_supported`, `memory_allocated`, and `set_stream_by_id` (see the sketch after this list).
- MTIA-specific codegen logic, for example, loading the MTIA dynamic library.
- Other necessary changes to integrate with Inductor codegen, following other devices like CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor, and to use the -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as a backend, using the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.
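
A hedged sketch of the `torch.mtia` APIs named above, assuming they mirror their `torch.cuda` counterparts; exact signatures and argument names may differ:

```python
import torch

# Sketch only: requires an MTIA build; signatures assumed to mirror torch.cuda.
if torch.mtia.is_available():
    print(torch.mtia.is_bf16_supported())  # bf16 capability query
    print(torch.mtia.memory_allocated())   # allocated bytes on current device

    # Re-select the current stream by its packed id, analogous to the
    # CUDA stream helpers; the keyword names here are assumptions.
    s = torch.mtia.current_stream()
    torch.mtia.set_stream_by_id(
        stream_id=s.stream_id,
        device_index=s.device_index,
        device_type=s.device_type,
    )
```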

Note:
- This approach (option 3) aims to provide a PyTorch-native path for Inductor integration on MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and incurs eager kernel-launch overhead.
- MTIA will support another approach (option 2) to provide the best performance, based on WrapperFxCodegen. We should be able to reuse the fundamental changes from this diff for option 2, like the device interfaces, stream/event APIs, etc., especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion (option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc (option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
2025-07-29 17:03:24 +00:00


#include <ATen/ATen.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/mtia/Module.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::mtia {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Device initialization: guard against use after a bad fork, register the
  // fork handler, and lazily initialize the MTIA backend.
  m.def("_mtia_init", []() {
    TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA));
    torch::utils::register_fork_handler_for_device_init(at::kMTIA);
    at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
  });
  m.def("_mtia_isBuilt", []() {
    // Check if the MTIAHooks class has been registered with the registry.
    return at::detail::isMTIAHooksBuilt();
  });
  m.def("_mtia_isInBadFork", []() {
    return torch::utils::is_device_in_bad_fork(at::kMTIA);
  });

  // Stream queries: each of these lazily initializes the device first.
  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getCurrentStream(device_index);
  });
  m.def("_mtia_getCurrentRawStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getCurrentRawStream(device_index);
  });
  m.def("_mtia_deviceSynchronize", []() {
    torch::utils::device_lazy_init(at::kMTIA);
    at::detail::getMTIAHooks().deviceSynchronize(
        at::detail::getMTIAHooks().getCurrentDevice());
  });

  // Device exchange: a negative index is a no-op sentinel, mirroring the
  // CUDA bindings.
  m.def("_mtia_exchangeDevice", [](c10::DeviceIndex device_index) {
    if (device_index < 0) {
      return static_cast<c10::DeviceIndex>(-1);
    }
    return at::detail::getMTIAHooks().exchangeDevice(device_index);
  });
  m.def("_mtia_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
    if (device_index < 0) {
      return static_cast<c10::DeviceIndex>(-1);
    }
    return at::detail::getMTIAHooks().maybeExchangeDevice(device_index);
  });
  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getDefaultStream(device_index);
  });

  // Set the current stream from its packed (id, device index, device type)
  // representation.
  m.def(
      "_mtia_setStream",
      [](int64_t stream_id,
         c10::DeviceIndex device_index,
         int64_t device_type) {
        torch::utils::device_lazy_init(at::kMTIA);
        at::detail::getMTIAHooks().setCurrentStream(c10::Stream::unpack3(
            stream_id,
            device_index,
            static_cast<c10::DeviceType>(device_type)));
      });

  // Set the current stream from a c10::Stream, switching the current device
  // to the stream's device if they differ.
  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
    torch::utils::device_lazy_init(at::kMTIA);
    auto device = at::detail::getMTIAHooks().getCurrentDevice();
    if (device != stream.device_index()) {
      at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
    }
    at::detail::getMTIAHooks().setCurrentStream(stream);
  });

  // Memory introspection: the hooks return a raw PyObject*, which we steal
  // into a py::object so pybind11 manages the reference count.
  m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().memoryStats(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().getDeviceCapability(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_getDeviceProperties", [](c10::DeviceIndex device_index) {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().getDeviceProperties(device_index);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_emptyCache", []() { at::detail::getMTIAHooks().emptyCache(); });
  m.def(
      "_mtia_recordMemoryHistory",
      [](const std::optional<std::string>& enabled,
         const std::string& stacks,
         size_t max_entries) {
        at::detail::getMTIAHooks().recordMemoryHistory(
            enabled, stacks, max_entries);
      });
  m.def("_mtia_memorySnapshot", []() {
    PyObject* raw_pyobject =
        at::detail::getMTIAHooks().memorySnapshot(std::nullopt);
    return py::reinterpret_steal<py::object>(raw_pyobject);
  });
  m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
    at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
  });
  m.def("_mtia_getDeviceCount", []() {
    return at::detail::getMTIAHooks().deviceCount();
  });
  m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
    at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
  });
}

} // namespace torch::mtia
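
For orientation, a short sketch of how the bindings registered above surface in Python; they sit on `torch._C` and are wrapped by the public `torch.mtia` API. This assumes an MTIA-enabled build with a visible device:

```python
import torch

# Sketch: calling the private bindings from Module.cpp directly.
# In practice, the public torch.mtia wrappers call these for you.
torch._C._mtia_init()                        # fork handler + lazy device init
print(torch._C._mtia_isBuilt())              # True if MTIAHooks is registered
print(torch._C._mtia_getDeviceCount())       # number of visible MTIA devices
stream = torch._C._mtia_getCurrentStream(0)  # current stream on device 0
torch._C._mtia_deviceSynchronize()           # wait for all queued device work
```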