pytorch/torch/backends/mps/__init__.py
Nikita Shulga a06ec54d40 [MPS] Add API to query GPU core count (#160414)
Using good old IOKit to get the `gpu-core-count` property from the device implementing the `AGXAccelerator` service.
Expose it as `torch.backends.mps.get_core_count()` and make it accessible to Inductor via `MpsInterface`.

Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare it to `system_profiler SPDisplaysDataType|head -n10`
```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:

    Apple M1 Pro:

      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```

This would significantly improve occupancy for torch.compile-generated kernels.
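As a rough illustration of why the core count matters for occupancy, here is a hypothetical sizing heuristic (a sketch, not Inductor's actual occupancy model): using the 16 EU x 8 ALU x 24 thread figures from the `get_core_count` docstring below, it bounds a dispatch by the number of threads the GPU can keep resident.

```python
import torch

# Per the docstring in torch/backends/mps/__init__.py:
# 1 core = 16 Execution Units x 8 ALUs x 24 threads = 3072 concurrent threads.
THREADS_PER_CORE = 16 * 8 * 24  # 3072


def max_resident_threads() -> int:
    """Upper bound on threads the GPU can keep in flight at once
    (hypothetical heuristic, not Inductor's actual logic)."""
    return torch.backends.mps.get_core_count() * THREADS_PER_CORE


if torch.backends.mps.is_available():
    # e.g. 16 cores * 3072 = 49152 on a 16-core M1 Pro
    print(max_resident_threads())
```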

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
2025-08-14 00:05:17 +00:00

from functools import lru_cache as _lru_cache
from typing import Optional, TYPE_CHECKING

import torch
from torch.library import Library as _Library

__all__ = [
    "get_core_count",
    "get_name",
    "is_built",
    "is_available",
    "is_macos13_or_newer",
    "is_macos_or_newer",
]


def is_built() -> bool:
    r"""Return whether PyTorch is built with MPS support.

    Note that this doesn't necessarily mean MPS is available; just that
    if this PyTorch binary were run on a machine with working MPS drivers
    and devices, we would be able to use it.
    """
    return torch._C._has_mps


@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if MPS is currently available."""
    return torch._C._mps_is_available()


@_lru_cache
def is_macos_or_newer(major: int, minor: int) -> bool:
    r"""Return a bool indicating whether MPS is running on the given macOS version or newer."""
    return torch._C._mps_is_on_macos_or_newer(major, minor)


@_lru_cache
def is_macos13_or_newer(minor: int = 0) -> bool:
    r"""Return a bool indicating whether MPS is running on macOS 13 or newer."""
    return torch._C._mps_is_on_macos_or_newer(13, minor)


@_lru_cache
def get_name() -> str:
    r"""Return the Metal device name."""
    return torch._C._mps_get_name()


@_lru_cache
def get_core_count() -> int:
    r"""Return the GPU core count.

    According to the documentation, one core comprises 16 Execution Units,
    each Execution Unit has 8 ALUs, and each ALU can run 24 threads, i.e.
    one core can execute 16 * 8 * 24 = 3072 threads concurrently.
    """
    return torch._C._mps_get_core_count()


_lib: Optional[_Library] = None


def _init() -> None:
    r"""Register prims as implementation of var_mean and group_norm."""
    global _lib
    if _lib is not None or not is_built():
        return

    from torch._decomp.decompositions import native_group_norm_backward
    from torch._refs import native_group_norm

    _lib = _Library("aten", "IMPL")  # noqa: TOR901
    _lib.impl("native_group_norm", native_group_norm, "MPS")
    _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")