Using good old IOKit to get the `gpu-core-count` property from the device implementing the `AGXAccelerator` service. Expose it as `torch.backends.mps.get_core_count()` and make it accessible via `MpsInterface` to the inductor.

Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare the output to `system_profiler SPDisplaysDataType|head -n10`:

```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:
    Apple M1 Pro:
      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```

This would significantly improve occupancy for torch.compile-generated kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
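As a quick illustration (a sketch, not part of the PR itself), the new call composes with the existing availability checks; the 3072 threads-per-core figure comes from the `get_core_count` docstring below:

```python
import torch

# get_core_count() assumes an MPS-enabled build running on Apple Silicon,
# so guard on availability first.
if torch.backends.mps.is_available():
    name = torch.backends.mps.get_name()         # e.g. "Apple M1 Pro"
    cores = torch.backends.mps.get_core_count()  # e.g. 16
    # 16 EUs/core * 8 ALUs/EU * 24 threads/ALU = 3072 threads per core
    print(f"{name}: {cores} cores, ~{cores * 3072} concurrently resident threads")
```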
from functools import lru_cache as _lru_cache
from typing import Optional, TYPE_CHECKING

import torch
from torch.library import Library as _Library


__all__ = [
    "get_core_count",
    "get_name",
    "is_built",
    "is_available",
    "is_macos13_or_newer",
    "is_macos_or_newer",
]


def is_built() -> bool:
    r"""Return whether PyTorch is built with MPS support.

    Note that this doesn't necessarily mean MPS is available; just that
    if this PyTorch binary were run on a machine with working MPS drivers
    and devices, we would be able to use it.
    """
    return torch._C._has_mps


@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if MPS is currently available."""
    return torch._C._mps_is_available()


@_lru_cache
def is_macos_or_newer(major: int, minor: int) -> bool:
    r"""Return a bool indicating whether MPS is running on the given macOS version or newer."""
    return torch._C._mps_is_on_macos_or_newer(major, minor)


@_lru_cache
def is_macos13_or_newer(minor: int = 0) -> bool:
    r"""Return a bool indicating whether MPS is running on macOS 13 or newer."""
    return torch._C._mps_is_on_macos_or_newer(13, minor)


@_lru_cache
def get_name() -> str:
    r"""Return the Metal device name."""
    return torch._C._mps_get_name()


@_lru_cache
def get_core_count() -> int:
    r"""Return the GPU core count.

    According to the documentation, one core comprises 16 Execution Units,
    each Execution Unit has 8 ALUs, and each ALU can run 24 threads; that is,
    one core is capable of executing 16 * 8 * 24 = 3072 threads concurrently.
    """
    return torch._C._mps_get_core_count()


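# A minimal sketch (illustrative only, not used by this module) of how the
# figures above could translate into an occupancy estimate for a kernel
# launch; `threads_per_core` and `max_resident` are hypothetical names:
#
#     threads_per_core = 16 * 8 * 24               # EUs * ALUs * threads = 3072
#     max_resident = get_core_count() * threads_per_core

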
_lib: Optional[_Library] = None


def _init() -> None:
    r"""Register refs as the MPS implementations of native_group_norm and its backward."""
    global _lib

    if _lib is not None or not is_built():
        return

    from torch._decomp.decompositions import native_group_norm_backward
    from torch._refs import native_group_norm

    _lib = _Library("aten", "IMPL")  # noqa: TOR901
    _lib.impl("native_group_norm", native_group_norm, "MPS")
    _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")
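

# After _init() runs, group_norm on MPS tensors dispatches to the Python
# reference implementations registered above. A minimal sketch, assuming an
# MPS device is available:
#
#     x = torch.randn(2, 4, 8, device="mps")
#     y = torch.nn.functional.group_norm(x, num_groups=2)  # hits native_group_norm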