Files
pytorch/torch/mtia/memory.py
anwang c55e72bea1 [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)
The previous [diff/PR](https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:
*(screenshot of the docstring lint error)*
I didn't add the docstring because I thought I wasn't supposed to add one to an EXISTING function.

So this diff/PR is an exact copy of the previous one, except for adding the docstring.

-------------
This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code (Triton kernels + Python wrapper code) similar to what it generates for CUDA, and the Triton kernels can be launched eagerly.
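
For illustration, a minimal sketch of what this enables, assuming an MTIA-enabled build; the `"mtia"` device string and the tensor shapes are only for the example:

```python
import torch

def pointwise(x, y):
    return torch.relu(x * y + 1.0)

# Inductor generates a Triton kernel plus Python wrapper code for MTIA,
# analogous to the CUDA path; the kernel is then launched eagerly.
compiled = torch.compile(pointwise, backend="inductor")

x = torch.randn(1024, device="mtia")
y = torch.randn(1024, device="mtia")
out = compiled(x, y)
```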

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs on device, stream, event, etc.
- Add required torch.mtia APIs, like is_bf16_supported, memory_allocated, set_stream_by_id, etc. (see the sketch after this list).
- MTIA-specific codegen logic, for example loading the MTIA dynamic_library.
- Other necessary changes to integrate with Inductor codegen, following other devices such as CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor and to use the -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as backend, using the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.
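
A hedged sketch of the eager-side torch.mtia APIs referenced above; it assumes an MTIA-enabled build, and the exact names follow the PR description rather than any specific PyTorch release:

```python
import torch

if torch.mtia.is_available():
    print(torch.mtia.is_bf16_supported())   # capability query used by Inductor codegen
    print(torch.mtia.memory_allocated())    # bytes currently held by tensors on device 0

    s = torch.mtia.current_stream()         # stream/event interfaces consumed by Dynamo
    torch.mtia.synchronize()
```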

Note:
- This approach (option 3) aims to provide a PyTorch-native path for Inductor integration on MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and is limited by eager kernel-launch overhead.
- MTIA will support another approach (option 2), based on WrapperFxCodegen, to provide the best performance. We should be able to reuse the fundamental changes in this diff for option 2, like the device interfaces, stream/event APIs, etc., especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion(option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc(option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
2025-07-29 17:03:24 +00:00


# pyre-strict
r"""This package adds support for device memory management implemented in MTIA."""

from typing import Any, Optional

import torch

from . import _device_t, is_initialized
from ._utils import _get_device_index


def memory_stats(device: Optional[_device_t] = None) -> dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not is_initialized():
        return {}
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def max_memory_allocated(device: Optional[_device_t] = None) -> int:
    r"""Return the maximum MTIA memory allocated in bytes for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not is_initialized():
        return 0
    # Default the "dram" entry to an empty dict so a missing key yields 0
    # instead of raising AttributeError on an int.
    return memory_stats(device).get("dram", {}).get("peak_bytes", 0)


def memory_allocated(device: Optional[_device_t] = None) -> int:
    r"""Return the current MTIA memory occupied by tensors in bytes for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not is_initialized():
        return 0
    return memory_stats(device).get("dram", {}).get("allocated_bytes", 0)


def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None:
    r"""Reset the peak memory stats for a given device.

    Args:
        device (torch.device, str, or int, optional): selected device. Resets
            statistics for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    if not is_initialized():
        return
    torch._C._mtia_resetPeakMemoryStats(_get_device_index(device, optional=True))


__all__ = [
    "memory_stats",
    "max_memory_allocated",
    "memory_allocated",
    "reset_peak_memory_stats",
]
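
A minimal usage sketch of the functions above; on builds where MTIA has not been initialized they simply return empty or zero values, since `is_initialized()` is false:

```python
from torch.mtia import memory

# Full allocator statistics for the current device (empty dict if MTIA
# has not been initialized).
stats = memory.memory_stats()

print(memory.memory_allocated())       # current DRAM bytes held by tensors
print(memory.max_memory_allocated())   # peak DRAM bytes since the last reset
memory.reset_peak_memory_stats()       # clear the peak counters
```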