[Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)

The previous [diff/PR](https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:
<img width="1736" height="722" alt="image" src="https://github.com/user-attachments/assets/216b1720-4002-48da-b5f3-32b5d48aaa54" />
I didn't add the docstring because I thought I wasn't supposed to add one for an EXISTING function.

So this diff/PR is an exact copy of the previous one, except that it adds the docstring.

-------------
This diff/PR includes the changes needed to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` on MTIA: Inductor generates code (Triton kernels plus Python wrapper code) similar to what it produces for CUDA, and the Triton kernels can be launched eagerly.
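
For context, a minimal usage sketch (not part of this diff; it assumes a PyTorch build with MTIA support, an available MTIA device, and the Triton MTIA backend) looks like this:

```python
# Minimal sketch: exercise the native Inductor path on MTIA.
# Assumes torch was built with MTIA support and an MTIA device is visible.
import torch

def f(x, y):
    return torch.relu(x @ y + y)

x = torch.randn(64, 64, device="mtia")
y = torch.randn(64, 64, device="mtia")

compiled_f = torch.compile(f, backend="inductor")
out = compiled_f(x, y)  # Inductor emits Triton kernels + a Python wrapper; kernels launch eagerly
```

Under the hood, the generated wrapper imports the MTIA `dynamic_library` module and allocates outputs via `empty_strided_mtia`, as shown in the diffs below.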

The changes include:
- Add the MTIA device interface used by Dynamo and Inductor, including device, stream, and event APIs (see the registration sketch after this list).
- Add required `torch.mtia` APIs, such as `is_bf16_supported`, `memory_allocated`, and `_set_stream_by_id`.
- Add MTIA-specific codegen logic, for example loading the MTIA `dynamic_library` in the generated wrapper.
- Other changes necessary to integrate with Inductor codegen, following other devices such as CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor and to use the `-USE_MTIA` preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as a backend, via the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.
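
For orientation, here is an illustrative sketch (not part of the diff) of how the new pieces plug into the existing registration hooks; the actual calls are made inside `init_device_reg()` and `init_backend_registration()` in the files below:

```python
# Illustrative sketch of the registration flow added by this diff; module paths
# follow the imports shown in the file diffs below.
from torch._dynamo.device_interface import MtiaInterface, register_interface_for_device
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp_wrapper_gpu import CppWrapperGpu
from torch._inductor.codegen.python_wrapper_mtia import PythonWrapperMtia
from torch._inductor.codegen.triton import TritonScheduling

# Dynamo side: expose device/stream/event APIs for "mtia" devices.
register_interface_for_device("mtia", MtiaInterface)

# Inductor side: reuse TritonScheduling for kernel codegen, plus an MTIA-flavored
# Python wrapper that layers MTIA imports on top of PythonWrapperCodegen.
register_backend_for_device("mtia", TritonScheduling, PythonWrapperMtia, CppWrapperGpu)
```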

Note:
- This approach (option 3) provides a PyTorch-native path for Inductor integration on MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and is bounded by the overhead of eager kernel launches.
- MTIA will support another approach (option 2), based on WrapperFxCodegen, to provide the best performance. We should be able to reuse the fundamental changes of this diff for option 2, such as the device interfaces and stream/event APIs, especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion (options 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc (option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
anwang, 2025-07-28 12:37:07 -07:00, committed by PyTorch MergeBot
parent 750348b579
commit c55e72bea1
18 changed files with 232 additions and 5 deletions

View File

@ -16,4 +16,5 @@ The MTIA backend is implemented out of the tree, only interfaces are defined
:nosignatures:
memory_stats
memory_allocated
```

View File

@ -1948,8 +1948,10 @@ def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_getCurrentRawStream(device: _int) -> _int: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _mtia_memoryStats(device: _int) -> dict[str, Any]: ...
def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ...
def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ...

View File

@ -2,10 +2,10 @@
Device abstraction layer for TorchDynamo and Inductor backends.
This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS) through a common device interface. Key components include:
CPU, MPS, MTIA) through a common device interface. Key components include:
- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
@ -287,6 +287,87 @@ class CudaInterface(DeviceInterface):
raise RuntimeError("triton not built with the 'nvidia' backend")
get_mtia_stream: Optional[Callable[[int], int]]
if torch.mtia._is_compiled():
from torch._C import _mtia_getCurrentRawStream as get_mtia_stream
else:
get_mtia_stream = None
class MtiaInterface(DeviceInterface):
device = torch.mtia.device # type: ignore[assignment]
Event = torch.mtia.Event # type: ignore[assignment]
Stream = torch.mtia.Stream # type: ignore[assignment]
class Worker:
@staticmethod
def set_device(device: int) -> None:
caching_worker_current_devices["mtia"] = device
@staticmethod
def current_device() -> int:
if "mtia" in caching_worker_current_devices:
return caching_worker_current_devices["mtia"]
return torch.mtia.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None) -> Any:
if device is not None:
if isinstance(device, str):
device = torch.device(device)
assert device.type == "mtia"
if isinstance(device, torch.device):
device = device.index
if device is None:
device = MtiaInterface.Worker.current_device()
if "mtia" not in caching_worker_device_properties:
device_prop = [
torch.mtia.get_device_properties(i)
for i in range(torch.mtia.device_count())
]
caching_worker_device_properties["mtia"] = device_prop
return caching_worker_device_properties["mtia"][device]
current_device = staticmethod(torch.mtia.current_device)
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
device_count = staticmethod(torch.mtia.device_count)
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
current_stream = staticmethod(torch.mtia.current_stream)
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
synchronize = staticmethod(torch.mtia.synchronize)
get_device_properties = staticmethod(torch.mtia.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_mtia_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type]
memory_allocated = staticmethod(torch.mtia.memory_allocated) # type: ignore[assignment]
is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported) # type: ignore[arg-type]
# Can be mock patched by @patch decorator.
@staticmethod
def is_available() -> bool:
ret = torch.mtia.is_available()
return ret
@staticmethod
def get_compute_capability(device: torch.types.Device = None) -> Any:
cc = torch.mtia.get_device_capability(device)
return cc
@staticmethod
def is_triton_capable(device: torch.types.Device = None) -> bool:
return True
@staticmethod
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
import triton.backends
if "mtia" not in triton.backends.backends:
raise RuntimeError("triton not built with the 'mtia' backend")
get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
@ -509,6 +590,10 @@ def init_device_reg() -> None:
for i in range(torch.xpu.device_count()):
register_interface_for_device(f"xpu:{i}", XpuInterface)
register_interface_for_device("mtia", MtiaInterface)
for i in range(torch.mtia.device_count()):
register_interface_for_device(f"mtia:{i}", MtiaInterface)
register_interface_for_device("cpu", CpuInterface)
register_interface_for_device("mps", MpsInterface)

View File

@ -3967,7 +3967,7 @@ def is_compile_supported(device_type):
compile_supported = is_dynamo_supported()
if type == "cpu":
pass
elif type in ["cuda", "xpu"] and compile_supported:
elif type in ["cuda", "xpu", "mtia"] and compile_supported:
compile_supported = has_triton()
else:
compile_supported = False

View File

@ -16,6 +16,7 @@ from .template_heuristics import (
BaseConfigHeuristic,
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
@ -65,6 +66,8 @@ class InductorChoices:
return XPUConfigHeuristic()
elif device_type == "cpu":
return CPUConfigHeuristic()
elif device_type == "mtia":
return MTIAConfigHeuristic()
else:
return BaseConfigHeuristic()

View File

@ -484,6 +484,10 @@ def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
@functools.cache
def init_backend_registration() -> None:
"""
Register the backend for different devices, including the scheduling
for kernel code generation and the host side wrapper code generation.
"""
from .cpp import CppScheduling
from .cpp_wrapper_cpu import CppWrapperCpu
from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef
@ -492,6 +496,7 @@ def init_backend_registration() -> None:
from .cuda_combined_scheduling import CUDACombinedScheduling
from .halide import HalideScheduling
from .mps import MetalScheduling
from .python_wrapper_mtia import PythonWrapperMtia
from .triton import TritonScheduling
from .wrapper import PythonWrapperCodegen
@ -539,6 +544,14 @@ def init_backend_registration() -> None:
CppWrapperMps,
)
if get_scheduling_for_device("mtia") is None:
register_backend_for_device(
"mtia",
TritonScheduling,
PythonWrapperMtia,
CppWrapperGpu,
)
private_backend = torch._C._get_privateuse1_backend_name()
if (
private_backend != "privateuseone"
@ -584,6 +597,7 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides:
if not device_op_overrides_dict:
from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401
from .cuda import device_op_overrides # noqa: F401
from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401
from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
return device_op_overrides_dict[device]

View File

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from ..common import DeviceOpOverrides, register_device_op_overrides
class MTIADeviceOpOverrides(DeviceOpOverrides):
def import_get_raw_stream_as(self, name: str) -> str:
return f"from torch._C import _mtia_getCurrentRawStream as {name}"
def set_device(self, device_idx: int) -> str:
return f"torch.mtia.set_device({device_idx})"
def synchronize(self) -> str:
return "torch.mtia.synchronize()"
def device_guard(self, device_idx: int) -> str:
return f"torch.mtia.device({device_idx})"
register_device_op_overrides("mtia", MTIADeviceOpOverrides())

View File

@ -0,0 +1,34 @@
from typing import Optional
from typing_extensions import override
from torch._inductor import ir
from .wrapper import PythonWrapperCodegen
class PythonWrapperMtia(PythonWrapperCodegen):
"""
A thin wrapper of PythonWrapperCodegen with MTIA specific logic
"""
@override
def write_header(self) -> None:
super().write_header()
# MTIA specific imports
self.imports.splice("import mtia.host_runtime.torch_mtia.dynamic_library")
@override
@staticmethod
def create(
is_subgraph: bool,
subgraph_name: Optional[str],
parent_wrapper: Optional[PythonWrapperCodegen],
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
) -> PythonWrapperCodegen:
if is_subgraph:
# Delegate to the parent class to handle the case of subgraph
return PythonWrapperCodegen.create(
is_subgraph, subgraph_name, parent_wrapper, partition_signatures
)
return PythonWrapperMtia()

View File

@ -997,6 +997,7 @@ class PythonWrapperCodegen(CodeGen):
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
@ -2777,7 +2778,7 @@ class PythonWrapperCodegen(CodeGen):
allocation_shape
)
codegen_stride_tuple = self.codegen_python_shape_tuple(stride)
if device.type in ("cpu", "cuda", "xpu"):
if device.type in ("cpu", "cuda", "xpu", "mtia"):
# optimized path for faster allocations, saving ~2us versus the stuff below
out = (
f"{name} = empty_strided_{device.type}("

View File

@ -156,6 +156,8 @@ class DeviceProperties(typing.NamedTuple):
elif device_type == "mps":
# TODO: Fetch the actual value from ioreg
multi_processor_count = 8
elif device_type == "mtia":
multi_processor_count = 64
else:
raise
return cls(

View File

@ -1201,3 +1201,9 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
"""
Placeholder child class for XPU specific overrides.
"""
class MTIAConfigHeuristic(BaseConfigHeuristic):
"""
Placeholder child class for MTIA specific overrides.
"""

View File

@ -94,7 +94,7 @@ if TYPE_CHECKING:
from .scheduler import BaseSchedulerNode, SchedulerBuffer
GPU_TYPES = ["cuda", "mps", "xpu"]
GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
T = TypeVar("T")

View File

@ -31,6 +31,10 @@
#include <ATen/xpu/EmptyTensor.h>
#endif
#ifdef USE_MTIA
#include <ATen/native/mtia/EmptyTensor.h>
#endif
#include <chrono>
#include <sstream>
#include <tuple>
@ -1059,6 +1063,12 @@ static PyObject* _empty_strided_device(
return THPVariable_Wrap(at::detail::empty_strided_xpu(
sizes, strides, dtype, c10::DeviceType::XPU));
}
#endif
#ifdef USE_MTIA
else if (device_type == c10::DeviceType::MTIA) {
return THPVariable_Wrap(at::detail::empty_strided_mtia(
sizes, strides, dtype, c10::DeviceType::MTIA));
}
#endif
else {
TORCH_CHECK(
@ -1084,6 +1094,10 @@ static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
}
static PyObject* _empty_strided_mtia(PyObject* dummy, PyObject* args) {
return _empty_strided_device(dummy, args, c10::DeviceType::MTIA);
}
static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
static PythonArgParser parser(
@ -1115,6 +1129,7 @@ static PyMethodDef _methods[] = {
{"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
{"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
{"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
{"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr},
{"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};

View File

@ -63,6 +63,18 @@ void initModule(PyObject* module) {
return at::detail::getMTIAHooks().getDefaultStream(device_index);
});
m.def(
"_mtia_setStream",
[](int64_t stream_id,
c10::DeviceIndex device_index,
int64_t device_type) {
torch::utils::device_lazy_init(at::kMTIA);
at::detail::getMTIAHooks().setCurrentStream(c10::Stream::unpack3(
stream_id,
device_index,
static_cast<c10::DeviceType>(device_type)));
});
m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
torch::utils::device_lazy_init(at::kMTIA);
auto device = at::detail::getMTIAHooks().getCurrentDevice();

View File

@ -204,6 +204,10 @@ def attach_out_of_memory_observer(
torch._C._mtia_attachOutOfMemoryObserver(observer)
def is_bf16_supported(including_emulation: bool = True):
return True
def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]:
r"""Return capability of a given device as a tuple of (major version, minor version).
@ -335,6 +339,17 @@ class StreamContext:
torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type]
def _set_stream_by_id(stream_id, device_index, device_type):
r"""set stream specified by the stream id, device index and
device type
Args: stream_id (int): stream id in stream pool
device_index (int): device index in topo
device_type (int): enum device type
"""
torch._C._mtia_setStream(stream_id, device_index, device_type)
def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
r"""Wrap around the Context-manager StreamContext that selects a given stream.
@ -392,6 +407,7 @@ __all__ = [
"default_stream",
"memory_stats",
"max_memory_allocated",
"memory_allocated",
"reset_peak_memory_stats",
"get_device_capability",
"get_device_properties",
@ -405,4 +421,5 @@ __all__ = [
"device",
"set_rng_state",
"get_rng_state",
"is_bf16_supported",
]

View File

@ -36,6 +36,19 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int:
return memory_stats(device).get("dram", 0).get("peak_bytes", 0)
def memory_allocated(device: Optional[_device_t] = None) -> int:
r"""Return the current MTIA memory occupied by tensors in bytes for a given device.
Args:
device (torch.device or int or str, optional): selected device. Returns
statistic for the current device, given by :func:`~torch.mtia.current_device`,
if :attr:`device` is ``None`` (default).
"""
if not is_initialized():
return 0
return memory_stats(device).get("dram", 0).get("allocated_bytes", 0)
def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
r"""Reset the peak memory stats for a given device.
@ -53,5 +66,6 @@ def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
__all__ = [
"memory_stats",
"max_memory_allocated",
"memory_allocated",
"reset_peak_memory_stats",
]

View File

@ -135,6 +135,7 @@ def has_triton() -> bool:
"cuda": cuda_extra_check,
"xpu": _return_true,
"cpu": cpu_extra_check,
"mtia": _return_true,
}
def is_device_compatible_with_triton() -> bool: