[Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)

The previous [diff/PR] (https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:
<img width="1736" height="722" alt="image" src="https://github.com/user-attachments/assets/216b1720-4002-48da-b5f3-32b5d48aaa54" />
I didn't add the docstring cause I thought I'm not supposed to add docstring for an EXISTING function.

So this diff/PR is an exactly copy of the previous one, except for adding the docstring.

-------------
This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code(triton kernel + python wrapper code) similar to CUDA. And the triton kernels can be launched eagerly.

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs on device, stream, event, etc.
- Add required torch.mtia APIs, like is_bf16_supported, memory_allocated, set_stream_by_id, etc.
- MTIA specific codegen logic, for example, loading MTIA dynamic_library.
- Other necessary changes to integrate with Inductor codegn, following other devices like CUDA, XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in Inductor runtime to avoid re-initialize MTIADriver.
- BUCK changes to include ATen-mtia in Inductor, and to use -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as backend, using the `--use_native_inductor` flag.
- Add a personal script(`scripts/anwang/run_native_inductor_script.py`) for testing purpose.

Note:
- This approach(option 3) aims to provide a pytorch native approach of Inductor integration for MTIA, minimizing the onboarding overhead. The downside of this approach is that it doesn't leverage MTIA specific graph optimization, and is limited to eagerly launch overhead.
- MTIA will support another approach(option 2) to provide best performance, based on WrapperFxCodegen. We should be able to reuse the fundamental changes of this diff for option 2, like the device interfaces, steam/event APIs, etc, especially as WrapperFxCodegen inherits PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion(option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc(option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototying diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
This commit is contained in:
anwang
2025-07-28 12:37:07 -07:00
committed by PyTorch MergeBot
parent 750348b579
commit c55e72bea1
18 changed files with 232 additions and 5 deletions

View File

@ -484,6 +484,10 @@ def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
@functools.cache
def init_backend_registration() -> None:
"""
Register the backend for different devices, including the scheduling
for kernel code generation and the host side wrapper code generation.
"""
from .cpp import CppScheduling
from .cpp_wrapper_cpu import CppWrapperCpu
from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
@ -492,6 +496,7 @@ def init_backend_registration() -> None:
from .cuda_combined_scheduling import CUDACombinedScheduling
from .halide import HalideScheduling
from .mps import MetalScheduling
from .python_wrapper_mtia import PythonWrapperMtia
from .triton import TritonScheduling
from .wrapper import PythonWrapperCodegen
@ -539,6 +544,14 @@ def init_backend_registration() -> None:
CppWrapperMps,
)
if get_scheduling_for_device("mtia") is None:
register_backend_for_device(
"mtia",
TritonScheduling,
PythonWrapperMtia,
CppWrapperGpu,
)
private_backend = torch._C._get_privateuse1_backend_name()
if (
private_backend != "privateuseone"
@ -584,6 +597,7 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides:
if not device_op_overrides_dict:
from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401
from .cuda import device_op_overrides # noqa: F401
from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401
from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
return device_op_overrides_dict[device]