Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20)
[MPS] Add API to query GPU core count (#160414)
Using good old IOKit to read the `gpu-core-count` property from the device implementing the `AGXAccelerator` service. Expose it as `torch.backends.mps.get_core_count()` and make it accessible to inductor via `MpsInterface`.

Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare the output to `system_profiler SPDisplaysDataType|head -n10`:
```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:
    Apple M1 Pro:

      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```
This would significantly improve occupancy for torch.compile-generated kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
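As a sketch of the sizing decision this enables (an illustrative heuristic only, not the actual inductor logic; it uses this PR's `torch.backends.mps` API and the 3072-threads-per-core figure from the new docstring further down):

```python
import torch

def max_concurrent_threads() -> int:
    # Illustrative: core count times the documented 3072 threads per core.
    return torch.backends.mps.get_core_count() * 3072

if torch.backends.mps.is_available():
    print(max_concurrent_threads())  # e.g. 16 * 3072 = 49152 on an M1 Pro
```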
Committed by PyTorch MergeBot · parent 50a8c11875 · commit a06ec54d40
```diff
@@ -1196,7 +1196,7 @@ if(APPLE)
   string(
     APPEND
     CMAKE_SHARED_LINKER_FLAGS
-    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal"
+    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit"
   )
   # To suppress MPSGraph availability warnings
   append_cxx_flag_if_supported("-Wno-unguarded-availability-new"
```
```diff
@@ -55,6 +55,17 @@ class TORCH_API MPSDevice {
    */
   bool isMacOS13Plus(MacOSVersion version) const;
 
+  /**
+   * Returns device name
+   */
+  std::string getName() const;
+
+  /**
+   * Returns number of GPU cores.
+   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
+   */
+  unsigned getCoreCount() const;
+
   ~MPSDevice();
 
  private:
```
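Spelling out the doc comment's arithmetic as a quick sketch (numbers taken from the comment above):

```python
# 1 core = 16 Execution Units x 8 ALUs x 24 threads (per the comment above).
threads_per_core = 16 * 8 * 24
print(threads_per_core)       # -> 3072
print(16 * threads_per_core)  # e.g. an M1 Pro's 16 cores -> 49152 threads
```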
```diff
@@ -85,10 +85,36 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
   }
 }
 
+std::string MPSDevice::getName() const {
+  @autoreleasepool {
+    return [[_mtl_device name] UTF8String];
+  }
+}
+
+unsigned MPSDevice::getCoreCount() const {
+  io_iterator_t iterator = 0;
+  io_registry_entry_t entry = 0;
+  int core_count = 0;
+  auto matchingDict = IOServiceMatching("AGXAccelerator");
+  TORCH_INTERNAL_ASSERT(matchingDict, "Failed to create matching dict");
+  const auto status = IOServiceGetMatchingServices(kIOMainPortDefault, matchingDict, &iterator);
+  TORCH_INTERNAL_ASSERT(status == KERN_SUCCESS);
+  while ((entry = IOIteratorNext(iterator)) != 0) {
+    auto property = IORegistryEntryCreateCFProperty(entry, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
+    auto found = CFNumberGetValue(static_cast<CFNumberRef>(property), kCFNumberIntType, &core_count);
+    CFRelease(property);
+    IOObjectRelease(entry);
+    if (found) {
+      break;
+    }
+  }
+  IOObjectRelease(iterator);
+  return core_count;
+}
+
 at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
   return getIMPSAllocator(useSharedAllocator);
 }
 
 bool is_available() {
   return MPSDevice::getInstance()->device() != nil;
 }
```
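To cross-check the value this IOKit walk returns, the same `gpu-core-count` registry property is visible from the stock `ioreg` CLI. A best-effort sketch, independent of this PR (the exact output layout may vary across macOS versions):

```python
import subprocess

# Best-effort read of the "gpu-core-count" IORegistry property via ioreg.
out = subprocess.run(["ioreg", "-l"], capture_output=True, text=True).stdout
for line in out.splitlines():
    if "gpu-core-count" in line:
        # Lines look like: | |   "gpu-core-count" = 16
        print(int(line.rsplit("=", 1)[-1]))
        break
```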
```diff
@@ -1979,7 +1979,9 @@ def _mtia_resetPeakMemoryStats(device: _int) -> None: ...
 
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
+def _mps_get_core_count() -> _int: ...
 def _mps_get_default_generator() -> Generator: ...
+def _mps_get_name() -> _str: ...
 def _mps_emptyCache() -> None: ...
 def _mps_setMemoryFraction(fraction: _float) -> None: ...
 def _mps_currentAllocatedMemory() -> _int: ...
```
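These stubs only type the private bindings registered in torch/csrc/mps/Module.cpp below. A minimal sketch of calling them directly (assumes an MPS-enabled build of this PR; the `torch.backends.mps` wrappers are the supported surface):

```python
import torch

core_count: int = torch._C._mps_get_core_count()
device_name: str = torch._C._mps_get_name()
print(device_name, core_count)  # e.g. "Apple M1 Pro 16"
```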
```diff
@@ -17,6 +17,7 @@ specialized implementations for each hardware backend's unique features.
 
 import inspect
 import time
+from collections import namedtuple
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, Union
@@ -544,8 +545,10 @@ class MpsInterface(DeviceInterface):
 
     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
-            return {}
+        def get_device_properties(device: torch.types.Device = None) -> Any:
+            return namedtuple("MPSProperties", ["multi_processor_count"])(
+                torch.backends.mps.get_core_count()  # type: ignore[arg-type]
+            )
 
         @staticmethod
         def current_device() -> int:
```
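With the import and the Worker change in place, inductor sees a real core count through the usual device-properties path. A small sketch of the new behavior (assumes a macOS build containing this PR):

```python
import torch
from torch._dynamo.device_interface import MpsInterface

if torch.backends.mps.is_available():
    props = MpsInterface.Worker.get_device_properties()
    print(props.multi_processor_count)  # e.g. 16 on an M1 Pro
```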
```diff
@@ -153,9 +153,6 @@ class DeviceProperties(typing.NamedTuple):
         except AttributeError:
             if device_type == "xpu":
                 multi_processor_count = props.gpu_subslice_count
-            elif device_type == "mps":
-                # TODO: Fetch the actual value from ioreg
-                multi_processor_count = 8
             elif device_type == "mtia":
                 multi_processor_count = 64
             else:
```
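A condensed sketch of why the hardcoded branch can be dropped: `props` now carries a real `multi_processor_count` for MPS (via `MpsInterface.Worker.get_device_properties` above), so the generic attribute lookup succeeds before the fallback chain is consulted. `resolve_mp_count` below is a hypothetical simplification, not the actual `DeviceProperties` code:

```python
from collections import namedtuple

# Hypothetical condensation of the fallback logic after this change.
def resolve_mp_count(props, device_type: str) -> int:
    try:
        return props.multi_processor_count  # now succeeds for MPS
    except AttributeError:
        if device_type == "mtia":
            return 64
        raise

MPSProps = namedtuple("MPSProps", ["multi_processor_count"])
print(resolve_mp_count(MPSProps(16), "mps"))  # -> 16
```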
```diff
@@ -5,7 +5,14 @@ import torch
 from torch.library import Library as _Library
 
 
-__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+__all__ = [
+    "get_core_count",
+    "get_name",
+    "is_built",
+    "is_available",
+    "is_macos13_or_newer",
+    "is_macos_or_newer",
+]
 
 
 def is_built() -> bool:
@@ -36,6 +43,23 @@ def is_macos13_or_newer(minor: int = 0) -> bool:
     return torch._C._mps_is_on_macos_or_newer(13, minor)
 
 
+@_lru_cache
+def get_name() -> str:
+    r"""Return Metal device name"""
+    return torch._C._mps_get_name()
+
+
+@_lru_cache
+def get_core_count() -> int:
+    r"""Return GPU core count.
+
+    According to the documentation, one core is comprised of 16 Execution Units.
+    One Execution Unit has 8 ALUs.
+    And one ALU can run 24 threads, i.e. one core is capable of executing 3072 threads concurrently.
+    """
+    return torch._C._mps_get_core_count()
+
+
 _lib: Optional[_Library] = None
```
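Typical use of the new public helpers (a sketch; requires a macOS build with MPS support). Both are `@_lru_cache`'d, so the underlying Metal/IOKit queries run at most once per process:

```python
import torch

if torch.backends.mps.is_available():
    print(torch.backends.mps.get_name())        # e.g. "Apple M1 Pro"
    print(torch.backends.mps.get_core_count())  # e.g. 16
```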
```diff
@@ -20,7 +20,6 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
 #include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
```
```diff
@@ -501,6 +501,12 @@ void initModule(PyObject* module) {
     at::mps::getMPSProfiler().startCapture(fileName);
   });
   m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
+  m.def("_mps_get_name", []() {
+    return at::mps::MPSDevice::getInstance()->getName();
+  });
+  m.def("_mps_get_core_count", []() {
+    return at::mps::MPSDevice::getInstance()->getCoreCount();
+  });
 }
 #endif /* USE_MPS */
```
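The `m.def(...)` bindings surface as private helpers on `torch._C`; the `torch.backends.mps` functions above are thin cached wrappers around them. A quick consistency sketch (assumes an MPS build):

```python
import torch

assert torch._C._mps_get_name() == torch.backends.mps.get_name()
assert torch._C._mps_get_core_count() == torch.backends.mps.get_core_count()
```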