Addressing #154890
Not really a proper fix but at least it's more informative than the current crash.
For a more long term solution I'm testing if we can use the TopK API released in MacOS14 as it does not have the same MPSScan op issue that the Sort and ArgSort are hitting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155475
Approved by: https://github.com/kulinseth
Fixes#154615
Enables using ConvTranspose3D since it seems support exists both on MacOS 14 and 15.
For the half dtypes the discrepancy of CPU and GPU implementations is too large to conclude whether there is a bug in the implementation or not without a more rigorous study on what bounds are there to the expected error. So they are left unsupported for now and an assert is added to notify the user if the op is called with fp16 or bf16 inputs.
Tests for ConvTranspose3D were enabled for the supported data types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154696
Approved by: https://github.com/malfet
This accomplishes following:
- Fixes correctness problem with large integer types (though probably makes it slower, but this could not be avoided if one wants to compute accurate answer)
- Makes op faster for floating point types (as Metal kernel invocation is faster than creating MPSGraph)
- Eliminates need for several correctness workarounds
Fixes https://github.com/pytorch/pytorch/issues/154171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op, but needs a small tweak to handle float scalars
Make sure that `out` is resized to an expected value in `isin_Tensor_Tensor_out_mps`
Add unittests to validate that, but skip them on MacOS-13, where MPS op just returns garbage
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154010
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/manuelcandales
ghstack dependencies: #153970, #153971, #153997
Current implementation causes silent correction problem with torch.compile when someone tries to `torch.compile` function where one of the arguments is say `np.exp(.3)`, which will be represented as torch.float64 scalar tensor
Add regssion test for this behavior
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153582
Approved by: https://github.com/dcci
By computing matmuls of only one random non-zero batch on CPU
This reduces test runtime from 11 minutes to 14 sec
```
% python3 test/test_mps.py -v -k test_large_bmm_
test_large_bmm_bfloat16 (__main__.TestMPS.test_large_bmm_bfloat16) ... ok
test_large_bmm_float16 (__main__.TestMPS.test_large_bmm_float16) ... ok
----------------------------------------------------------------------
Ran 2 tests in 27.495s
```
TODO: Compute it over two slices when https://github.com/pytorch/pytorch/issues/153560 is fixed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153562
Approved by: https://github.com/Skylion007, https://github.com/clee2000
By implementing `div_floor` and `div_trunc` . Do not mark `div_trunc` as OPMATH, to align following output with CPU(if division is performed in fp32, than result will be truncated to 25
```
import torch
print(torch.tensor([[-7.4688, -3.1289]], dtype=torch.float16,device="cpu").div(torch.tensor([-0.2988, -0.8789], dtype=torch.bfloat16,device="cpu"), rounding_mode="trunc"))
tensor([[24., 3.]])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152758
Approved by: https://github.com/dcci
ghstack dependencies: #152663, #152515, #152737, #152743
Fixes#152344
Leak seems to be on the MPS Graph side, even though there is an identity tensor it seems like it's no longer enough to bypass the SDPA sequence which seems to leak memory.
Even adding 0.0f seems to be optimized to be ignored and still take the sdpa sequence(that's the reason for adding 1e-20)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152371
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Implements layernorm forward pass as a metal kernel instead of MPSGraph ops. Speed ups are indicated on the chart below:

Script for generating times, need to build torch with old/new codebase and then run this with different file name indicated at the end of the script
```python
import csv
import time
import numpy as np
import torch
import torch.nn.functional as F
matrix_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
batch_sizes = [1]
elementwise_affine = [False, True]
num_runs = 50
warmup_runs = 3
def create_input_tensor(n, batch_size):
torch.manual_seed(42)
return torch.randn(batch_size, n, dtype=torch.float32)
def run_layer_norm(A, normalized_shape, elementwise_affine):
torch.mps.synchronize()
start = time.perf_counter()
out = F.layer_norm(A, normalized_shape)
torch.mps.synchronize()
end = time.perf_counter()
return out, end - start
results = {"N": [], "elementwise_affine": [], "batch_size": [], "mean_time": [], "std_time": []}
for el_aff in elementwise_affine:
for n in matrix_sizes:
for batch_size in batch_sizes:
print(f"\nBenchmarking LayerNorm for input size N={n}, batch_size={batch_size}, elementwise_affine={el_aff}")
try:
A_cpu = create_input_tensor(n, batch_size)
A_mps = A_cpu.to("mps")
normalized_shape = (n,)
for _ in range(warmup_runs):
_, _ = run_layer_norm(A_mps, normalized_shape, el_aff)
times = []
for _ in range(num_runs):
_, t = run_layer_norm(A_mps, normalized_shape, el_aff)
times.append(t)
mean_time = np.mean(times)
std_time = np.std(times)
results["N"].append(n)
results["elementwise_affine"].append(el_aff)
results["batch_size"].append(batch_size)
results["mean_time"].append(mean_time)
results["std_time"].append(std_time)
print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
except RuntimeError as e:
print(f"Error for N={n}, batch_size={batch_size}: {e}")
continue
with open("layernorm_benchmark_times_new.csv", "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["N", "elementwise_affine", "batch_size", "mean_time", "std_time"])
for i in range(len(results["N"])):
writer.writerow(
[
results["N"][i],
results["elementwise_affine"][i],
results["batch_size"][i],
results["mean_time"][i],
results["std_time"][i],
]
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152010
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
By reusing `c10/metal/atomic.h`
This also fixes `GPUTests.test_index_put_fallback[12]_mps` that is unrolled by inductor, so no need for dedicated atomic_add support
TODOs:
- Get rid of indexing kernel and compute it directly when kernel is run
- Simulate atomic_add for int64 types as series of int32 atomic-add-and-fetch
- Setup tolerances correctly to pass float16/bfloat16 tests (as CPU always takes sequential strategy)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
To follow pattern set by CPU and CUDA impls: define common_dtype and optionally casts `elements` and `test_elements` to common dtype if needed
- Add regression test, though skip it on MacOS-13, as `isin` seems to produce garbage there even for same dtypes
```
>>> import torch
>>> x=torch.arange(4.0, device='mps')
>>> y=torch.arange(1.0, 3.0, device='mps')
>>> x, y, torch.isin(x, y), torch.isin(y, x)
(tensor([0., 1., 2., 3.], device='mps:0'), tensor([1., 2.], device='mps:0'), tensor([False, True, False, False], device='mps:0'), tensor([False, False], device='mps:0'))
>>> torch.__version__
'2.6.0'
```
- Cleanup code a bit
Fixes https://github.com/pytorch/pytorch/issues/151443
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151600
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/kulinseth
To workaround limitation of 32-arguments per kernel and being able to eventually compile something like
```python
import torch
def foo(*args):
rc = torch.empty_like(args[0])
for arg in args:
rc += arg
return rc
tensors = torch.rand(100, 32, device='mps').unbind(0)
print(torch.compile(foo)(*tensors))
```
For now, introduce `at::native:🤘:get_tensor_gpu_address` and use it from both C++ test and compile_shader to convert list of tensors to list of pointers valid on GPU.
Initially this binding were done via `id< MTLArgumentEncoder>`, but according to [Improving CPU Performance by Using Argument Buffers](https://developer.apple.com/documentation/metal/improving-cpu-performance-by-using-argument-buffers?language=objc#Encode-Resources-into-Argument-Buffers) article, this is not necessary when targeting Tier2-only devices (which is true of all devices on MacOS-13 or newer):
> To directly encode the argument buffer resources on these Tier 2 devices, write the [MTLBuffer](https://developer.apple.com/documentation/metal/mtlbuffer?language=objc).[gpuAddress](https://developer.apple.com/documentation/metal/mtlbuffer/gpuaddress?language=objc) property — and for other resource types (samplers, textures, and acceleration structures), the [gpuResourceID](https://developer.apple.com/documentation/metal/mtlcomputepipelinestate/gpuresourceid?language=objc) property — into the corresponding structure member. To encode offsets, treat these property values as uint64 types and add the offset to them.
Add both C++ and PyThon unittests that validate that this works.
Please note, that using either ArgumentEncoder or directly encoding the data does not guarantee buffer will not be freed until shader execution is complete. On the other hand, this should already be guaranteed by MPSCachingAllocator that would only free the memory after all streams completed its execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150780
Approved by: https://github.com/dcci
Fixes#142397
Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests
Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device.
By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`
This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions
Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512 | 202.1 | 131.8 |
| 1024x1024 | 780.6 | 176.9 |
| 2048x2048 | 1423.4 | 339.9 |
| 4096x4097 | 2982.2 | 1047.2 |
Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
- Distinguish between conjugated/non_conjugated inputs by appending conjugation to the operator key
- For matmul or dot, add `conjugateWithTensor:name:` calls before running the op
- Enable testing for conjugated ops by passing `include_conjugated_inputs` to opinfo
- Filter `include_conjugated_inputs` argument from `sample_inputs_window` (probably should have landed as separate PR)
- Preserve conj property when gathering the views, that fixes `cov` operator
Fixes https://github.com/pytorch/pytorch/issues/148156
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150157
Approved by: https://github.com/dcci
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that tests `fmax`/`fmin` for mixed dtypes.
Been dreaded to write this PR for a while, as it end up to be pretty bulky:
- Adds 1C10_METAL_ALL_TYPES_FUNCTOR` and `c10:🤘:ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
- Add `c10:🤘:cast_to` to `c10/metal/utils.h` which could be used to cast any scalar metal dtype to any other one, including complex values
- Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)` that is used to dynamically cast types
- Add `binary_strided_cast` and `binary_dense_cast` that are invoked for output dtype and cast both inputs to that output before performing the op
Benchmark collected on M2Pro that runs fmax for 1 mln element tensors (Times are in microseconds.)
| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8
TODOS:
- Include input and output dtype in non-cast kernel name
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dytpes testing via OpInfo
Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
Implements nanmedian on MPS. This implementation only implements `torch.nanmedian(tensor)` without `keepdim` and `dim`
Will implement nanmedian with dim and keepdim in a followup
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149407
Approved by: https://github.com/malfet