This fixes `index_put(..., accumulate=True)` for all dtypes.
The int64 operation is not truly atomic, but it is eventually consistent from the `index_put_accumulate` kernel's point of view: by the end of the operation, the results in global memory are indeed the accumulation of the operands at the given indices.
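A minimal sketch of the user-visible behavior (assumes an MPS device; values are arbitrary):
```python
# Repeated indices accumulate correctly even for int64, where each update is
# only eventually consistent rather than a single hardware atomic.
import torch

x = torch.zeros(3, dtype=torch.int64, device="mps")
idx = torch.tensor([0, 0, 2, 2, 2], device="mps")
vals = torch.ones(5, dtype=torch.int64, device="mps")
x.index_put_((idx,), vals, accumulate=True)
print(x.cpu())  # expected: tensor([2, 0, 3])
```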
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
Move the `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract an `iter_tensor_offset` method that returns the offset from the start of the storage associated with a given tensor inside the iterator.
Migrated `index` and `index_put[_accumulate][_serial]` to the new paradigm, which requires neither an additional tensor for indices nor special handling for 32- vs 64-bit offsets. This resulted in an almost 2x perf gain for a 2000x2000 tensor; see the results below. Before:
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
While migrating, also fixed `index_put_accumulate` for boolean types by using a compare_and_exchange trick over uint.
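A hedged illustration of the now-working boolean path (accumulation on a bool tensor behaves like a per-index logical OR):
```python
import torch

mask = torch.zeros(4, dtype=torch.bool, device="mps")
idx = torch.tensor([1, 1, 3], device="mps")
vals = torch.tensor([True, False, True], device="mps")
mask.index_put_((idx,), vals, accumulate=True)
print(mask.cpu())  # expected: tensor([False,  True, False,  True])
```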
Fixes https://github.com/pytorch/pytorch/issues/153560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158064
Approved by: https://github.com/dcci
They might have been slow on CUDA-11.3, but that version of CUDA is long gone. The more fundamental underlying issue is the linear complexity of the recursive polynomial definitions for higher-order polynomials; see, for example, this loop from the implementation of the Chebyshev polynomial of the first kind:
7081b8233a/aten/src/ATen/native/Math.h (L2969-L2973)
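For reference, the recurrence behind that loop looks roughly like the Python sketch below (not the actual C++ in Math.h); the loop runs n-1 times, so the cost grows linearly with the order n:
```python
# Three-term recurrence: T_0(x) = 1, T_1(x) = x, T_k(x) = 2*x*T_{k-1}(x) - T_{k-2}(x)
def chebyshev_t(x: float, n: int) -> float:
    if n == 0:
        return 1.0
    prev, cur = 1.0, x
    for _ in range(2, n + 1):  # O(n) iterations
        prev, cur = cur, 2.0 * x * cur - prev
    return cur
```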
These recursive implementations were tested by `test_compare_cpu` using the following values (as sample index 16):
7081b8233a/torch/testing/_internal/opinfo/core.py (L2079)
Luckily, Chebyshev polynomials evaluated at absolute values greater than 1 reach infinity pretty quickly, see below:
```
python3 -c "import torch;print(torch.special.chebyshev_polynomial_v(torch.nextafter(torch.tensor(1.0), torch.tensor(2.0)), torch.tensor(1e6)))"
tensor(nan)
```
This is not the case for Laguerre polynomials, but it's probably fine to just limit it to 1e7.
Before
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss..ssssss..ssssss..ssssssssssssssssssssss..ssssss/home/ubuntu/py3.10-nightly/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.)
return torch._C._get_cublas_allow_tf32()
....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssssssssssss..ssssss..ssssssssssssssssssssssssssssss..ssssss....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssssssssssss
----------------------------------------------------------------------
Ran 432 tests in 8.575s
OK (skipped=344)
```
After
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss........................ssssssssssssssss......../home/ubuntu/pytorch/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/ubuntu/pytorch/aten/src/ATen/Context.cpp:78.)
return torch._C._get_cublas_allow_tf32()
........................................................................................xxxxxxxx................ssssssssssssssssssssssss........................................................................................................ssssssss........................ssssssss........................................................................................ssssssss
----------------------------------------------------------------------
Ran 432 tests in 45.580s
OK (skipped=72, expected failures=8)
```
Fixes https://github.com/pytorch/pytorch/issues/79528
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157464
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157488
Introduce `c10::metal::remainder` and call it from both the inductor and eager implementations, with an integer specialization, which should make it much faster than before while remaining compliant with the Python way of rounding negative numbers.
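As a reminder of the semantics being matched (a small illustration in Python, not the Metal code itself): Python's remainder takes the sign of the divisor, unlike C-style truncation.
```python
import torch

print(-7 % 5)                                       # 3: Python rounds toward -infinity
print(torch.remainder(torch.tensor(-7), 5).item())  # 3: matches Python semantics
print(torch.fmod(torch.tensor(-7), 5).item())       # -2: C-style truncated division
```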
Introducing it also allows one to remove the complex type detection logic from the MPS codegen and rely on the Metal (C++) type system to figure out the input and output types.
This fixes compilation of something like
```python
@torch.compile
def f(x, y):
    return x[y % 5]
```
which previously failed to compile with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
#include <c10/metal/utils.h>
kernel void generated_kernel(
device float* out_ptr0,
constant long* in_ptr0,
constant float* in_ptr1,
uint xindex [[thread_position_in_grid]]
) {
int x0 = xindex;
auto tmp0 = in_ptr0[x0];
auto tmp1 = 12;
auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
auto tmp3 = 1024;
auto tmp4 = static_cast<long>(tmp3);
auto tmp5 = tmp2 + tmp4;
auto tmp6 = tmp2 < 0;
auto tmp7 = tmp6 ? tmp5 : tmp2;
if ((tmp7 < 0) && (tmp7 > 1024)) return;
auto tmp9 = in_ptr1[tmp7];
out_ptr0[x0] = static_cast<float>(tmp9);
}
with program_source:372:28: error: array subscript is not an integer
auto tmp9 = in_ptr1[tmp7];
^~~~~
```
This fixes `fail_to_compile` for the GPT2ForSequenceClassification Hugging Face model using `transformers==4.44.2`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155891
Approved by: https://github.com/manuelcandales
Implements the forward and backward hardshrink operators as Metal kernels.
In order to support the lambda parameter, we extend the `exec_unary_kernel` and `exec_binary_kernel` methods. Now they take an optional Scalar and an optional ScalarType argument. When the optional ScalarType is provided, it overrides the type of the Scalar.
We add a new `REGISTER_UNARY_ALPHA_OP` macro, and modify the existing `REGISTER_BINARY_ALPHA_OP` to support the new feature.
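A quick usage check (a hedged sketch assuming an MPS device; shapes and `lambd` are arbitrary):
```python
import torch
import torch.nn.functional as F

x = torch.randn(8, device="mps", requires_grad=True)
y = F.hardshrink(x, lambd=0.5)  # forward kernel: zeroes values with |x| <= 0.5
y.sum().backward()              # backward kernel
print(x.grad.cpu())             # 0 where |x| <= 0.5, 1 elsewhere
```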
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155304
Approved by: https://github.com/malfet
Metal arguments must be 8-byte aligned (or possibly 16-byte aligned), so running any strided (or typecast) binary op with MTL_DEBUG_LAYER enabled leads to an exception:
```
% MTL_DEBUG_LAYER=1 python3 ../test/test_mps.py -v -k test_output_match_add
2025-06-05 15:41:34.201 Python[86653:16826825] Metal API Validation Enabled
test_output_match_add_mps_bfloat16 (__main__.TestConsistencyMPS.test_output_match_add_mps_bfloat16) ...
validateComputeFunctionArguments:1083: failed assertion `Compute Function(add_strided_bfloat_bfloat): argument ndim[0] from buffer(7) with offset(0) and length(12) has space for 12 bytes, but argument has a length(16).'
zsh: abort MTL_DEBUG_LAYER=1 python3 ../test/test_mps.py -v -k test_output_match_add
```
Extend it to 4 elements and pass the output dtype, which will be used by binary_op later anyway.
Test plan: Run the above-mentioned command with `MTL_DEBUG_LAYER=1` and make sure everything passes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155272
Approved by: https://github.com/angelayi, https://github.com/dcci, https://github.com/cyyever
This accomplishes the following:
- Fixes a correctness problem with large integer types (it probably makes the op slower, but that cannot be avoided if one wants an accurate answer; see the sketch after this list)
- Makes the op faster for floating point types (as a Metal kernel invocation is faster than creating an MPSGraph)
- Eliminates the need for several correctness workarounds
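The sketch below is purely illustrative of why an accurate integer path matters: a hypothetical lower-precision (float32) route cannot represent large int64 values exactly.
```python
# Illustrative only: float32 has a 24-bit mantissa, so int64 values above 2**24
# cannot survive a round-trip through it exactly.
import torch

big = torch.tensor(2**24 + 1, dtype=torch.int64)
print(big.to(torch.float32).to(torch.int64).item())  # 16777216, not 16777217
```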
Fixes https://github.com/pytorch/pytorch/issues/154171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
What was initially supposed to be a very straightforward change resulted in a small refactor of the binary op tensor generators when invoked for mixed dtypes, which surfaced via the `test_output_grad_match_sinc_mps_float16` test failure.
If the operands have different dtypes (in particular, a float16 tensor and a float32 scalar), one must perform the operation at `opmath_t` (or `TensorIterator::common_dtype()`) precision rather than cast both operands to the output dtype and compute afterwards, as demonstrated by the following example:
```
>>> torch.tensor([-1.8633, 6.2031, -2.2500, -3.3926, 8.5938, 5.9766], dtype=torch.half).mul(torch.pi)
tensor([ -5.8555, 19.4844, -7.0703, -10.6562, 27.0000, 18.7812],
dtype=torch.float16)
>>> torch.tensor([-1.8633, 6.2031, -2.2500, -3.3926, 8.5938, 5.9766], dtype=torch.half).mul(torch.tensor(torch.pi, dtype=torch.float16))
tensor([ -5.8516, 19.4844, -7.0664, -10.6562, 26.9844, 18.7656],
dtype=torch.float16)
```
Solve this problem for now by introducing `REGISTER_OPMATH_BINARY_OP`, which indicates that operands must be cast to `opmath_t` before performing the computation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152515
Approved by: https://github.com/Skylion007, https://github.com/kulinseth, https://github.com/dcci
ghstack dependencies: #152663
First of all, by extending `c10::metal::cast_to` to work correctly with complex dtypes via two more specializations: one that casts complex to scalar, and another that casts scalar to complex (as the default Metal typecast would turn `float x` into `float2(x, x)`).
Add ComplexHalf and ComplexFloat enum values to `c10::metal::ScalarType` and handle them in `val_at_offs(ptr, offs, type)`.
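A small sanity check of the scalar-to-complex direction (a hedged sketch; assumes the op is among those with complex support on MPS):
```python
# Adding a real Python scalar to a complex tensor should affect only the real
# part, i.e. the scalar must be cast to (x, 0) rather than (x, x).
import torch

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64, device="mps")
print((z + 2.0).cpu())  # expected: tensor([3.+2.j, 5.-4.j])
```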
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152504
Approved by: https://github.com/dcci
ghstack dependencies: #152443, #152466, #152479
By reusing `c10/metal/atomic.h`.
This also fixes `GPUTests.test_index_put_fallback[12]_mps`, which is unrolled by inductor, so there is no need for dedicated atomic_add support.
TODOs:
- Get rid of the indexing kernel and compute it directly when the kernel is run
- Simulate atomic_add for int64 types as a series of int32 atomic-add-and-fetch operations
- Set up tolerances correctly to pass float16/bfloat16 tests (as the CPU always takes the sequential strategy)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows a significant reduction in the amount of shared memory needed to perform those reductions.
Using such reductions increases `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes the perf of `torch.rand` as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512 | 202.1 | 131.8 |
| 1024x1024 | 780.6 | 176.9 |
| 2048x2048 | 1423.4 | 339.9 |
| 4096x4097 | 2982.2 | 1047.2 |
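Numbers of this shape can be collected with something like the sketch below (an assumption about methodology, not the exact harness used for the table above; requires an MPS device):
```python
import torch
from torch.utils.benchmark import Timer

for n in (512, 1024, 2048, 4096):
    t = Timer(
        stmt="torch.rand(n, n, device='mps'); torch.mps.synchronize()",
        globals={"torch": torch, "n": n},
    )
    print(f"{n}x{n}: {t.blocked_autorange(min_run_time=1).median * 1e6:.1f} us")
```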
Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using `simd_shuffle_down` of 64-bit values represented as `int2` types, which yields a reduction in $\log_2(\text{threadgroup\_size})$ steps. [`mlx/kernels/reduction/ops.h`](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such an algorithm, but alas it yields wrong results on M1/M2 (and maybe M3) machines if not all threads in the simdgroup are active, which can be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
which returns the following on M4
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0', dtype=torch.int32)
```
but the same kernel running on M1 returns
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernel using `simd_shuffle_and_fill_down` causes an internal compiler error on MacOS-13.2. Considering that this OS will be EOL soon, skip the offending tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
When generating a reduction kernel; otherwise the compiler can unroll loops so much that the kernel cannot be launched with the intended threadgroup size.
Extend `c10::metal::max` to accept different dtypes.
Together these changes fix `test_large_broadcast_reduction`.
TODO:
- Explore different threadgroup_sizes for best perf
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that test `fmax`/`fmin` for mixed dtypes.
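For example, a mixed-dtype call that now goes through the `_cast_` kernels (a hedged sketch, assuming an MPS device):
```python
import torch

a = torch.randn(1024, device="mps", dtype=torch.float32)
b = torch.randn(1024, device="mps", dtype=torch.float16)
out = torch.fmax(a, b)  # both inputs are cast to the common (output) dtype
print(out.dtype)        # torch.float32
```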
I've been dreading writing this PR for a while, as it ended up being pretty bulky:
- Adds `C10_METAL_ALL_TYPES_FUNCTOR` and `c10::metal::ScalarType` to `c10/metal/common.h` and tests that its values always match `c10::ScalarType`
- Adds `c10::metal::cast_to` to `c10/metal/utils.h`, which can be used to cast any scalar Metal dtype to any other one, including complex values
- Implements `val_at_offs<T>(constant void *, long offs, ScalarType dtype)`, which is used to dynamically cast types
- Adds `binary_strided_cast` and `binary_dense_cast`, which are invoked for the output dtype and cast both inputs to it before performing the op
Benchmark collected on an M2 Pro running fmax on 1M-element tensors (times are in microseconds):
| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8
TODOs:
- Include input and output dtypes in the non-cast kernel name
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed-dtype testing via OpInfo
Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales