Using good old IOKit to get `gpu-core-count` property from device implementing `AGXAccelerator` service
Expose this one as `torch.backend.mps.get_core_count()` and make it accessible via `MpsInterface` to the inductor
Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare it to `system_profiler SPDisplaysDataType|head -n10`
```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:
Apple M1 Pro:
Chipset Model: Apple M1 Pro
Type: GPU
Bus: Built-In
Total Number of Cores: 16
Vendor: Apple (0x106b)
Metal Support: Metal 3
```
This would significantly improve occupancy for torch.compile generated kernels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
To workaround limitation of 32-arguments per kernel and being able to eventually compile something like
```python
import torch
def foo(*args):
rc = torch.empty_like(args[0])
for arg in args:
rc += arg
return rc
tensors = torch.rand(100, 32, device='mps').unbind(0)
print(torch.compile(foo)(*tensors))
```
For now, introduce `at::native:🤘:get_tensor_gpu_address` and use it from both C++ test and compile_shader to convert list of tensors to list of pointers valid on GPU.
Initially this binding were done via `id< MTLArgumentEncoder>`, but according to [Improving CPU Performance by Using Argument Buffers](https://developer.apple.com/documentation/metal/improving-cpu-performance-by-using-argument-buffers?language=objc#Encode-Resources-into-Argument-Buffers) article, this is not necessary when targeting Tier2-only devices (which is true of all devices on MacOS-13 or newer):
> To directly encode the argument buffer resources on these Tier 2 devices, write the [MTLBuffer](https://developer.apple.com/documentation/metal/mtlbuffer?language=objc).[gpuAddress](https://developer.apple.com/documentation/metal/mtlbuffer/gpuaddress?language=objc) property — and for other resource types (samplers, textures, and acceleration structures), the [gpuResourceID](https://developer.apple.com/documentation/metal/mtlcomputepipelinestate/gpuresourceid?language=objc) property — into the corresponding structure member. To encode offsets, treat these property values as uint64 types and add the offset to them.
Add both C++ and PyThon unittests that validate that this works.
Please note, that using either ArgumentEncoder or directly encoding the data does not guarantee buffer will not be freed until shader execution is complete. On the other hand, this should already be guaranteed by MPSCachingAllocator that would only free the memory after all streams completed its execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150780
Approved by: https://github.com/dcci
Enables clang-tidy rule [`misc-use-internal-linkage`](https://clang.llvm.org/extra/clang-tidy/checks/misc/use-internal-linkage.html). This new check was introduced in Clang-Tidy 18 and is available due to recent update of Clang-Tidy 19.
The check marks functions and variables used only in the translation unit as static. Therefore undesired symbols are not leaked into other units, more link time optimisations are possible and the resulting binaries may be smaller.
The detected violations were mostly fixed by using static. In other cases, the symbols were indeed consumed by others files, then their declaring headers were included. Still some declarations were wrong and have been fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148948
Approved by: https://github.com/Skylion007
I.e. when `MTL_CAPTURE_ENABLED` environment variable is set to 1, one should be able to invoke wrap the code with `torch.mps.profiler.capture_metal` to generate gputrace for shaders invoked inside the context manager.
For example, code below:
```python
import torch
import os
def foo(x):
return x[:,::2].sin() + x[:, 1::2].cos()
if __name__ == "__main__":
os.environ["MTL_CAPTURE_ENABLED"] = "1"
x = torch.rand(32, 1024, device="mps")
with torch.mps.profiler.metal_capture("compiled_shader"):
torch.compile(foo)(x)
```
should capture the execution of a `torch.compile` generated shader
<img width="734" alt="image" src="https://github.com/user-attachments/assets/718ff64e-103b-4b11-b66c-c89cfc770b5d" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144561
Approved by: https://github.com/manuelcandales
ghstack dependencies: #144559, #144560
This allows one to do something like that
```python
import torch
x = torch.ones(10, device="mps")
m = torch.mps._compile_shader("""
kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
x[idx] += idx;
}
")
m.foo(x)
```
And in general enables writing custom operators using Metal shaders purely in Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
This allows one to do something like that
```python
import torch
x = torch.ones(10, device="mps")
m = torch.mps._compile_shader("""
kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
x[idx] += idx;
}
")
m.foo(x)
```
And in general enables writing custom operators using Metal shaders purely in Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
Prerequisite for adding more complex type support and FFT operation
Check using `conjugateWithTensor:name:` selector defined as follows
```objc
/// Returns the complex conjugate of the input tensor elements.
///
/// - Parameters:
/// - tensor: The input tensor.
/// - name: An optional string which serves as an identifier for the operation..
/// - Returns: A valid `MPSGraphTensor` object containing the elementwise result of the applied operation.
-(MPSGraphTensor *) conjugateWithTensor:(MPSGraphTensor *) tensor
name:(NSString * _Nullable) name
MPS_AVAILABLE_STARTING(macos(14.0), ios(17.0), tvos(17.0))
MPS_SWIFT_NAME( conjugate(tensor:name:) );
```
- Rename `isOnMacOS13orNewer(unsigned minor)` hook to `isOnMacOSorNewer(major, minor)`
- Replace `torch._C.__mps_is_on_macos_13_or_newer` with `torch._C._mps_is_on_macos_or_newer`
- Add `torch.backends.mps.is_macos_or_newer` public API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115512
Approved by: https://github.com/albanD
- Implement `MPSEventPool` to recycle events.
- Implement python bindings with `torch.mps.Event` class using the MPSEventPool backend. The current member functions of the Event class are `record()`, `wait()`, `synchronize()`, `query()`, and `elapsed_time()`.
- Add API to measure elapsed time between two event recordings.
- Added documentation for Event class to `mps.rst`.
- Added test case to `test_mps.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102121
Approved by: https://github.com/albanD, https://github.com/kulinseth
Will be needed if one wants to make accurate XFAIL validation
I.e. `torch.backends.mps.is_macos13_or_newer()` will return True if PyTorch is running on MacOS 13.0 or newer, `torch.backends.mps.is_macos13_or_newer(1)` will return True if running on MacOS 13.1 or newer and `torch.backends.mps.is_macos13_or_newer(2)` will return True if running on MacOS 13.2 or newer
Do not use 13.3 check as `@available` does not really work for shared libraries
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95065
Approved by: https://github.com/albanD
- To check for Memory Leaks in `test_mps.py`, set the env-variable `PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1` when running test_mps.py (used CUDA code as reference).
- Added support for the following new python interfaces in MPS module:
`torch.mps.[empty_cache(), set_per_process_memory_fraction(), current_allocated_memory(), driver_allocated_memory()]`
- Renamed `_is_mps_on_macos_13_or_newer()` to `_mps_is_on_macos_13_or_newer()`, and `_is_mps_available()` to `_mps_is_available()` to be consistent in naming with prefix `_mps`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94646
Approved by: https://github.com/malfet
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case
- Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case
- Enable the following python interfaces for MPS:
`torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Added some test cases in test_mps.py
- Added `mps.rst` to document the `torch.mps` module.
- Fixed the failure with `test_public_bindings.py`
Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94417
Approved by: https://github.com/albanD
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case
- Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case
- Enable the following python interfaces for MPS:
`torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Added some test cases in test_mps.py
- Added `mps.rst` to document the `torch.mps` module.
- Fixed the failure with `test_public_bindings.py`
Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94417
Approved by: https://github.com/albanD