[MPS] Add API to query GPU core count (#160414)

Uses good old IOKit to read the `gpu-core-count` property from the device implementing the `AGXAccelerator` service.
Exposes it as `torch.backends.mps.get_core_count()` and makes it accessible to Inductor via `MpsInterface`.
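
For reference, the same property can be read straight from the IORegistry without PyTorch. A minimal sketch, not part of this PR: it shells out to the `ioreg` CLI and assumes an Apple Silicon machine where the GPU registers under the `AGXAccelerator` class:
```
# Hedged sketch: read `gpu-core-count` from the IORegistry via `ioreg`.
# Assumes Apple Silicon, where the GPU is published under AGXAccelerator.
import re
import subprocess

out = subprocess.check_output(["ioreg", "-r", "-c", "AGXAccelerator"], text=True)
match = re.search(r'"gpu-core-count"\s*=\s*(\d+)', out)
print(int(match.group(1)) if match else "gpu-core-count not found")
```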

Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare its output to `system_profiler SPDisplaysDataType | head -n10`:
```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:

    Apple M1 Pro:

      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```

This should significantly improve occupancy for `torch.compile`-generated kernels, as Inductor no longer assumes a hard-coded core count of 8 for every MPS device.
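
To illustrate why the real core count matters, here is a hypothetical launch-sizing helper in the spirit of what this enables (`pick_grid_size` and the blocks-per-core factor are invented for the example; this is not Inductor's actual heuristic):
```
import torch

def pick_grid_size(numel: int, block: int = 1024) -> int:
    # Hypothetical helper, not Inductor's actual heuristic: launch enough
    # blocks to keep every GPU core busy rather than assuming the previous
    # hard-coded value of 8 cores for all MPS devices.
    cores = torch.backends.mps.get_core_count()
    needed = (numel + block - 1) // block  # blocks required to cover the input
    return min(needed, 4 * cores)  # cap at a few blocks per core

print(pick_grid_size(1 << 20))
```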

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
Nikita Shulga authored on 2025-08-13 14:57:18 -07:00, committed by PyTorch MergeBot
parent 50a8c11875, commit a06ec54d40
9 changed files with 77 additions and 9 deletions

View File

@@ -1196,7 +1196,7 @@ if(APPLE)
   string(
     APPEND
     CMAKE_SHARED_LINKER_FLAGS
-    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal"
+    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit"
   )
   # To suppress MPSGraph availability warnings
   append_cxx_flag_if_supported("-Wno-unguarded-availability-new"

View File

@@ -55,6 +55,17 @@ class TORCH_API MPSDevice {
    */
   bool isMacOS13Plus(MacOSVersion version) const;

+  /**
+   * Returns device name
+   */
+  std::string getName() const;
+
+  /**
+   * Returns number of GPU cores.
+   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
+   */
+  unsigned getCoreCount() const;
+
   ~MPSDevice();

  private:

View File

@@ -85,10 +85,42 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
   }
 }

+std::string MPSDevice::getName() const {
+  @autoreleasepool {
+    return [[_mtl_device name] UTF8String];
+  }
+}
+
+unsigned MPSDevice::getCoreCount() const {
+  io_iterator_t iterator = 0;
+  io_registry_entry_t entry = 0;
+  int core_count = 0;
+  // Match every device that implements the AGXAccelerator service
+  auto matchingDict = IOServiceMatching("AGXAccelerator");
+  TORCH_INTERNAL_ASSERT(matchingDict, "Failed to create matching dict");
+  const auto status = IOServiceGetMatchingServices(kIOMainPortDefault, matchingDict, &iterator);
+  TORCH_INTERNAL_ASSERT(status == KERN_SUCCESS);
+  // Walk the matched registry entries until one reports a gpu-core-count
+  while ((entry = IOIteratorNext(iterator)) != 0) {
+    auto property = IORegistryEntryCreateCFProperty(entry, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
+    bool found = false;
+    // Guard against entries that do not publish the property at all
+    if (property != nullptr) {
+      found = CFNumberGetValue(static_cast<CFNumberRef>(property), kCFNumberIntType, &core_count);
+      CFRelease(property);
+    }
+    IOObjectRelease(entry);
+    if (found) {
+      break;
+    }
+  }
+  IOObjectRelease(iterator);
+  return core_count;
+}
+
 at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
   return getIMPSAllocator(useSharedAllocator);
 }

 bool is_available() {
   return MPSDevice::getInstance()->device() != nil;
 }

View File

@@ -1979,7 +1979,9 @@ def _mtia_resetPeakMemoryStats(device: _int) -> None: ...

 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
+def _mps_get_core_count() -> _int: ...
 def _mps_get_default_generator() -> Generator: ...
+def _mps_get_name() -> _str: ...
 def _mps_emptyCache() -> None: ...
 def _mps_setMemoryFraction(fraction: _float) -> None: ...
 def _mps_currentAllocatedMemory() -> _int: ...

View File

@@ -17,6 +17,7 @@ specialized implementations for each hardware backend's unique features.

 import inspect
 import time
+from collections import namedtuple
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, Union
@@ -544,8 +545,10 @@ class MpsInterface(DeviceInterface):

     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
-            return {}
+        def get_device_properties(device: torch.types.Device = None) -> Any:
+            return namedtuple("MPSProperties", ["multi_processor_count"])(
+                torch.backends.mps.get_core_count()  # type: ignore[arg-type]
+            )

     @staticmethod
     def current_device() -> int:
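
A quick way to sanity-check the plumbing end to end (assuming an MPS-enabled build; the `torch._dynamo.device_interface` module path is an assumption of this example):
```
import torch
from torch._dynamo.device_interface import MpsInterface

# The namedtuple mirrors the `multi_processor_count` field that CUDA
# device properties expose, so Inductor heuristics can stay generic.
props = MpsInterface.Worker.get_device_properties()
print(props.multi_processor_count)  # e.g. 16 on an M1 Pro
```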

View File

@@ -153,9 +153,6 @@ class DeviceProperties(typing.NamedTuple):
         except AttributeError:
             if device_type == "xpu":
                 multi_processor_count = props.gpu_subslice_count
-            elif device_type == "mps":
-                # TODO: Fetch the actual value from ioreg
-                multi_processor_count = 8
             elif device_type == "mtia":
                 multi_processor_count = 64
             else:

View File

@@ -5,7 +5,14 @@
 import torch
 from torch.library import Library as _Library

-__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+__all__ = [
+    "get_core_count",
+    "get_name",
+    "is_built",
+    "is_available",
+    "is_macos13_or_newer",
+    "is_macos_or_newer",
+]


 def is_built() -> bool:
@@ -36,6 +43,23 @@ def is_macos13_or_newer(minor: int = 0) -> bool:
     return torch._C._mps_is_on_macos_or_newer(13, minor)


+@_lru_cache
+def get_name() -> str:
+    r"""Return the Metal device name."""
+    return torch._C._mps_get_name()
+
+
+@_lru_cache
+def get_core_count() -> int:
+    r"""Return the GPU core count.
+
+    According to the documentation, one GPU core comprises 16 execution units,
+    each execution unit has 8 ALUs, and each ALU can run 24 threads, i.e. one
+    core can execute 3072 threads concurrently.
+    """
+    return torch._C._mps_get_core_count()
+
+
 _lib: Optional[_Library] = None
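
Given the per-core figures in the docstring, the device-wide thread capacity follows directly; a small usage sketch:
```
import torch

cores = torch.backends.mps.get_core_count()
# 16 execution units x 8 ALUs x 24 threads = 3072 threads per core
print(f"{cores} cores, up to {cores * 3072} concurrently executing threads")
# e.g. on an M1 Pro: "16 cores, up to 49152 concurrently executing threads"
```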

View File

@@ -20,7 +20,6 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>

View File

@@ -501,6 +501,12 @@ void initModule(PyObject* module) {
     at::mps::getMPSProfiler().startCapture(fileName);
   });
   m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
+  m.def("_mps_get_name", []() {
+    return at::mps::MPSDevice::getInstance()->getName();
+  });
+  m.def("_mps_get_core_count", []() {
+    return at::mps::MPSDevice::getInstance()->getCoreCount();
+  });
 }

 #endif /* USE_MPS */