110 Commits

Author SHA1 Message Date
ad67170c8b [MPS] sparse matmuls (#165232)
Implements matmuls for sparse tensors. With this commit most of the core sparse operations should be implemented. Fixes:
https://github.com/pytorch/pytorch/issues/156540
https://github.com/pytorch/pytorch/issues/129842

Should be merged after:
https://github.com/pytorch/pytorch/pull/165102

To compare MPS and CPU, you can use this script:
```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)

    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```

## Tested on M1 Pro
<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet
2025-10-18 09:04:42 +00:00
791eff96c8 [MPS] Add igamma/igammac ops (#161927)
Fixes #161725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161927
Approved by: https://github.com/malfet
2025-09-02 20:52:02 +00:00
28ccc9e724 [MPS] Extend index_put to complex types (#160159)
And delete confusing supported types check.
Move all pseudo atomic (but eventually consistent) ops to `c10/metal/atomic.h` header

Fixes https://github.com/pytorch/pytorch/issues/160034
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160159
Approved by: https://github.com/manuelcandales, https://github.com/dcci, https://github.com/Skylion007
2025-08-08 21:54:30 +00:00
e2a5c42e7e [BE][MPS] Build metal kernels of MacOS-14+ (#159733)
Which makes `#if __METAL_VERSION__ >= 310` guards for `bfloat` use support unnecessary.
Rename `kernels_bfloat.metallib` into `kernels_basic` and remove custom build/selection logic.

Part of https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733
Approved by: https://github.com/dcci
ghstack dependencies: #159731, #159732
2025-08-03 20:53:58 +00:00
f946b25865 [MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed bug in `simd_argmax` when result of the `simd_ballot` were prematurely cast to `ushort` and adjusted unit test
- Fixed nan handling in compiled argmax, but can't reliably test it as MPS(eager) implementaiton of argmax is buggy

Now according to `bench_mps_ops.py` `max(x, dim=0)` is reliably faster than eager implementaiton:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
2025-07-31 16:18:32 +00:00
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those function as they are significantly faster (though much slower than eager, due to the arg part) than before, which could be verified by running the following script
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much tricker due to the nan-handling logic that need to be added there.

Also fixes `torch.max/min` compilation for half-precision types, added regression types for it.

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64 and `c10:🤘:pair` template, which are used by `simd_argmax` to return both value and index

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
9ca080db87 [MPS] Extend atomic operations to all int types (#158179)
That fixes `index_put(..., accumulate=True)` for all dtypes

int64 operation is not really atomic, but eventually consistent from the `index_put_accumulate` kernel point of view: i.e. by the end of the operation results in the global memory are indeed accumulation of the operands at given indices
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
2025-07-14 04:25:05 +00:00
beed033b6e [MPS] Fix index_kernel for large tensors (#158064)
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes https://github.com/pytorch/pytorch/issues/153560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158064
Approved by: https://github.com/dcci
2025-07-11 22:35:44 +00:00
39a8f66d59 [BE] Use simdgroup_size constexpr (#157751)
Instead of every shader defining it separately, move it to `c10/metal/common.h`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157746
2025-07-08 03:46:20 +00:00
0b73f7c871 [EZ][BE] Move array def to c10/metal/common.h (#157746)
And use proper type aliasing instead of weird _ARRAY_NS

Also use `uint64_t` instead of `ulong`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157746
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-07-08 03:46:20 +00:00
a952956d05 Add isnan exit condition to special ops (#157464)
They might have been slow on CUDA-11.3, but this version of CUDA is long gone. More fundamental underlying issue were linear complexity of the recursive polynomial definitions for higher order polynomials, for example see this loop from implementation of Chebyshev polynomial of the first kind
7081b8233a/aten/src/ATen/native/Math.h (L2969-L2973)
which were tested by `test_compare_cpu` using following values (as sample index 16)
7081b8233a/torch/testing/_internal/opinfo/core.py (L2079)

Luckily chebyshev polynomials for absolute values higher than 1 pretty quickly reach infinity, see below
```
python3 -c "import torch;print(torch.special.chebyshev_polynomial_v(torch.nextafter(torch.tensor(1.0), torch.tensor(2.0)), torch.tensor(1e6)))"
tensor(nan)
```
Which is not the case for Laguerre polynomials, but it's probably fine to just limit it to 1e7

Before
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss..ssssss..ssssss..ssssssssssssssssssssss..ssssss/home/ubuntu/py3.10-nightly/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.)
  return torch._C._get_cublas_allow_tf32()
....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssssssssssss..ssssss..ssssssssssssssssssssssssssssss..ssssss....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssssssssssss
----------------------------------------------------------------------
Ran 432 tests in 8.575s

OK (skipped=344)
```
After
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss........................ssssssssssssssss......../home/ubuntu/pytorch/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/ubuntu/pytorch/aten/src/ATen/Context.cpp:78.)
  return torch._C._get_cublas_allow_tf32()
........................................................................................xxxxxxxx................ssssssssssssssssssssssss........................................................................................................ssssssss........................ssssssss........................................................................................ssssssss
----------------------------------------------------------------------
Ran 432 tests in 45.580s

OK (skipped=72, expected failures=8)
```

Fixes https://github.com/pytorch/pytorch/issues/79528

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157464
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157488
2025-07-05 04:19:50 +00:00
ec816d73b4 [MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)
For eager and inductor

As for all other chebyshev ops, logic is simply compiled from 94716db222/aten/src/ATen/native/cuda/Math.cuh (L2821)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157488
Approved by: https://github.com/dcci
2025-07-03 15:48:37 +00:00
b6276a425f Revert "[MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)"
This reverts commit 9620994067b18e846a097d1e99af85ec2426ef0a.

Reverted https://github.com/pytorch/pytorch/pull/157488 on behalf of https://github.com/clee2000 due to caused slow test config to time out [GH job link](https://github.com/pytorch/pytorch/actions/runs/16037776972/job/45254574100) [HUD commit link](e124a0d88c) ([comment](https://github.com/pytorch/pytorch/pull/157464#issuecomment-3032676989))
2025-07-03 15:24:15 +00:00
9620994067 [MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)
For eager and inductor

As for all other chebyshev ops, logic is simply compiled from 94716db222/aten/src/ATen/native/cuda/Math.cuh (L2821)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157488
Approved by: https://github.com/dcci
ghstack dependencies: #157464
2025-07-02 23:29:35 +00:00
402ae09e41 [BE] fix typos in c10/ (#156078)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156078
Approved by: https://github.com/malfet, https://github.com/cyyever
2025-06-18 10:24:44 +00:00
a4ea242edc [MPS] Implement scan metal kernels (#156100)
Implements metal kernels for scan operations:
- Migrates cumsum and cumprod from MPSGraph implementation to Metal.
- Fixes #154881
- Adds MPS backend support for cummin and cummax

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156100
Approved by: https://github.com/malfet
2025-06-17 17:44:22 +00:00
b6add8c8ba [MPSInductor] Fix remainder implementation for int types (#155891)
Introduce `c10:🤘:remainder` and call it from both inductor and eager implementation, with integer specialization, which should make it much faster than before, while still compliant with Python way of rounding up negative numbers.

This allows one to remove complex type detection logic from mps codegen and rely on Metal(C++) type system to figure out input and output types.

This fixes compilation of something like
```python
@torch.compile
def f(x, y):
    return x[y % 5]
```

which beforehand failed to compile with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant long* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp0 = in_ptr0[x0];
        auto tmp1 = 12;
        auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
        auto tmp3 = 1024;
        auto tmp4 = static_cast<long>(tmp3);
        auto tmp5 = tmp2 + tmp4;
        auto tmp6 = tmp2 < 0;
        auto tmp7 = tmp6 ? tmp5 : tmp2;
        if ((tmp7 < 0) && (tmp7 > 1024)) return;
        auto tmp9 = in_ptr1[tmp7];
        out_ptr0[x0] = static_cast<float>(tmp9);
    }
 with program_source:372:28: error: array subscript is not an integer
        auto tmp9 = in_ptr1[tmp7];
                           ^~~~~
```

This fixes fail_to_compile for GPT2ForSequenceClassification Huggingface model using `transformers==4.44.2`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155891
Approved by: https://github.com/manuelcandales
2025-06-13 16:42:56 +00:00
48921721d8 [MPS] Fix binary builds (#155733)
Introduced by https://github.com/pytorch/pytorch/pull/155611

All functions in those headers must be static and inline
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155733
Approved by: https://github.com/seemethere, https://github.com/atalman
2025-06-11 22:55:33 +00:00
013cf1e330 [MPS] Move expm1 op to Metal (#155611)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155611
Approved by: https://github.com/malfet
2025-06-11 13:06:14 +00:00
0f47e76937 [MPS] Implement hardshrink metal kernel (#155304)
Implements the forward and backward hardshrink operators as Metal kernels.
In order to support the lambda parameter, we extend the `exec_unary_kernel`  and `exec_binary_kernel` methods. Now they take an optional Scalar and an optional ScalarType argument. When the optional ScalarType is provided, it overrides the type of the Scalar.
We add a new `REGISTER_UNARY_ALPHA_OP` macro, and modify the existing `REGISTER_BINARY_ALPHA_OP` to support the new feature.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155304
Approved by: https://github.com/malfet
2025-06-10 18:20:27 +00:00
f140fac8dc [MPS] Implement erfc (#155382)
And migrate `erf` to Metal kernel

Use `erf` approximations from https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/erf.h as previous approximation did not match the CPU implementation

After that, `erfc(x) := 1.0 - erf(x)`

Fixes https://github.com/pytorch/pytorch/issues/155337

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155382
Approved by: https://github.com/manuelcandales, https://github.com/dcci
2025-06-07 02:35:12 +00:00
231eb9902b [MPS][BE] Extend ndim_and_dtypes to 4 elements (#155272)
Metal arguments must be 8 bytes aliged (or may be 16 bytes), so running
any strided (or typecasted) binary op with MTL_DEBUG_LAYER leads to
exception
```
% MTL_DEBUG_LAYER=1 python3 ../test/test_mps.py -v -k test_output_match_add
2025-06-05 15:41:34.201 Python[86653:16826825] Metal API Validation Enabled
test_output_match_add_mps_bfloat16 (__main__.TestConsistencyMPS.test_output_match_add_mps_bfloat16) ...
validateComputeFunctionArguments:1083: failed assertion `Compute Function(add_strided_bfloat_bfloat): argument ndim[0] from buffer(7) with offset(0) and length(12) has space for 12 bytes, but argument has a length(16).'
zsh: abort      MTL_DEBUG_LAYER=1 python3 ../test/test_mps.py -v -k test_output_match_add
```

Extend it to 4 elements and pass output dtype, which will be used by
binary_op later on anyway

Test plan: Run abovementioned command with `MTL_DEBUG_LAYER=1` and make
sure everything passes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155272
Approved by: https://github.com/angelayi, https://github.com/dcci, https://github.com/cyyever
2025-06-06 14:20:21 +00:00
9cdce682a1 [MPS][BE] Reimplement log1p as Metal shader (#154936)
That should make it faster than MPSGraph implementation, but also
improves accuracy for small inputs, by using the algorithm described in [What Every Computer Scientist Should Know About Floating-Point Arithmetic](https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202), i.e. $log(1+x) = \frac{x * log(1+x)}{(1 + x) - 1}$ if $1 +x \neq 1$ else just $x$

Also tried using first 3 elements of Taylor series in Horner's form which also seems to work fine, i.e. $log(1+x) \approx x * (1 -x (\frac{1}{2} -  \frac{x}{3}))$

Replaced less accurate log1p implementation in `c10/metal/special_math.h` with generic one.

Parametrize and modify regression test to check for accuracy of small values

TODOs:
 - Do proper implementation for complex values as well, perhaps using 0408ba0a76/mlx/backend/metal/kernels/utils.h (L339)
 - May be implement it using Remez-like algorithm documented here 207f3b2b25/lib/msun/src/s_log1pf.c (L37)
 - Or use llvm's implementation from f393986b53/libclc/clc/lib/generic/math/clc_log1p.inc (L22)
 - Benchmark which algorithm is faster and delivers better accuracy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154936
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-06-03 14:10:13 +00:00
975bbc63db [MPS][BE] Move fmod/remainder to Metal ops (#154280)
This accomplishes following:
 - Fixes correctness problem with large integer types (though probably makes it slower, but this could not be avoided if one wants to compute accurate answer)
 - Makes op faster for floating point types (as Metal kernel invocation is faster than creating MPSGraph)
 - Eliminates need for several correctness workarounds

Fixes https://github.com/pytorch/pytorch/issues/154171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
2025-05-24 01:45:33 +00:00
58dc80dff6 [MPSInductor] Fix indexing calculation (#153997)
By using `c10:🤘:floor_divie` primitive

Which fixes `test_flip_cat_mps` test, and makes `doctr_reco_predictor` and `doctr_det_predictor` pass accuracy checks (at least locally, scheduled a workflow dispatch to validate it in CI)

Before this change following script generated different compile and eager results
```python
import torch

def foo(unsqueeze, unsqueeze_1):
    cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
    view = torch.ops.aten.view.default(cat_1, [4])
    slice_5 = torch.ops.aten.slice.Tensor(view, 0, 0, 3)
    rev_1 = torch.ops.aten.flip.default(slice_5, [0])
    return rev_1

if __name__ == "__main__":
    x = torch.arange(1.0, 3.0, device='mps').reshape(2, 1)
    y = torch.arange(5.0, 7.0, device='mps').reshape(2, 1)

    rc, (kernel,) = torch._inductor.utils.run_and_get_kernels(torch.compile(foo), x, y)
    print(kernel)
    print("Compile: ", rc)
    print("Eager: ", foo(x, y))
```
After this change
```
'''
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp6 = in_ptr0[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp11 = in_ptr1[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp0 = (2 + ((-1)*x0)) % (2);
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        auto tmp4 = 1;
        auto tmp5 = tmp1 < tmp4;
        auto tmp7 = tmp5 ? tmp6 : 0.0;
        auto tmp8 = tmp1 >= tmp4;
        auto tmp9 = 2;
        auto tmp10 = tmp1 < tmp9;
        auto tmp12 = tmp8 ? tmp11 : 0.0;
        auto tmp13 = tmp5 ? tmp7 : tmp12;
        out_ptr0[x0] = static_cast<float>(tmp13);
    }
'''
Compile:  tensor([2., 5., 1.], device='mps:0')
Eager:  tensor([2., 5., 1.], device='mps:0')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153997
Approved by: https://github.com/dcci
ghstack dependencies: #153970, #153971
2025-05-21 00:03:46 +00:00
e889937850 [MPS] Migrate div to Metal (#152743)
TODOs:
 - Verify accuracy of  `metal::dot` vs `x.x*x.x + y.y*y.y`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152743
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #152663, #152515, #152737
2025-05-04 00:56:19 +00:00
792736f9ac [BE][MPS] Pass alpha by reference (#152737)
As it's always a scalar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152737
Approved by: https://github.com/dcci
ghstack dependencies: #152663, #152515
2025-05-03 08:31:45 +00:00
34e9f0b5c6 [MPS] Migrate mul to TensorIterator (#152515)
What initially supposed to be a very straightforward change resulted in small refactor of binary op tensor generators when  invoked for mixed dtype, which surfaced via `test_output_grad_match_sinc_mps_float16` test failure.

If operands are of different dtype (in particular float16 tensor and float32 scalar), one must perform an operation with `opmath_t` (or `TensorIterator::common_dtype()`) precision, rather than casting both operands to output dtype and performing it then, which can be demonstrated via the following example:
```
>>> torch.tensor([-1.8633, 6.2031, -2.2500, -3.3926,  8.5938,  5.9766], dtype=torch.half).mul(torch.pi)
tensor([ -5.8555,  19.4844,  -7.0703, -10.6562,  27.0000,  18.7812],
       dtype=torch.float16)
>>> torch.tensor([-1.8633, 6.2031, -2.2500, -3.3926,  8.5938,  5.9766], dtype=torch.half).mul(torch.tensor(torch.pi, dtype=torch.float16))
tensor([ -5.8516,  19.4844,  -7.0664, -10.6562,  26.9844,  18.7656],
       dtype=torch.float16)
```

Solve this problem for now, but introducing `REGISTER_OPMATH_BINARY_OP` that indicates that operands must be cast to opmath_t, before performing the computation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152515
Approved by: https://github.com/Skylion007, https://github.com/kulinseth, https://github.com/dcci
ghstack dependencies: #152663
2025-05-03 02:35:03 +00:00
a2c553cac6 [Metal] Extend typecasted op support to complex dtypes (#152504)
First of all, by extending `c10:🤘:cast_to` to work correctly with complex dtypes, by introducing two more specializations: one that casts complex to scalar, and another that casts scalar to complex (as default metal typecast will turn `float x` into `float2(x, x)`)

Add ComplexHalf and ComplexFloat enum values to `c10:🤘:ScalarTypes` and handle them in `val_at_offs(ptr, offs, type)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152504
Approved by: https://github.com/dcci
ghstack dependencies: #152443, #152466, #152479
2025-04-30 05:32:07 +00:00
9bfdf57572 [MPS][BE] Introduce c10:🤘:mul (#152466)
Which multiplies two arguments for either scalar or complex data types

This allows one to get rid of bunch of complex specialization in BinaryOps
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152466
Approved by: https://github.com/dcci
ghstack dependencies: #152443
2025-04-30 04:45:47 +00:00
663bcb68ba Implement metal kernel for basic MPS arithmetic ops using TensorIterator (#147644)
Add metal kernels for add, subtract, & lerp ops using TensorIterator. Should help resolve: https://github.com/pytorch/pytorch/issues/143874
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147644
Approved by: https://github.com/malfet
2025-04-29 14:24:49 +00:00
e28864fc0f [MPS/inductor] Fix the approximation of polygamma for n == 0. (#152214)
Fixes #152205

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152214
Approved by: https://github.com/malfet
2025-04-25 22:42:45 +00:00
3aecf2dc52 [MPS] Extend index_put to half precision floats (#151869)
By reusing `c10/metal/atomic.h`
This also fixes `GPUTests.test_index_put_fallback[12]_mps` that is unrolled by inductor, so no need for dedicated atomic_add support

TODOs:
 - Get rid of indexing kernel and compute it directly when kernel is run
 - Simulate atomic_add for int64 types as series of int32 atomic-add-and-fetch
 - Setup tolerances correctly to pass float16/bfloat16 tests (as CPU always takes sequential strategy)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-04-22 22:00:08 +00:00
d778c92e16 [Metal][BE] Move atomic ops to c10/metal/atomic.h (#151868)
To be reused from indexing and MPSInductor implementaiton of atomic_add stores
Added wrapper for `metal::atomic<int>`(to be used by followup PR)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151868
Approved by: https://github.com/Skylion007
2025-04-22 14:11:29 +00:00
470132c6a1 [MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
2025-04-20 17:44:40 +00:00
9699cc3eb9 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 21:44:51 +00:00
7762bddd87 Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00836c23e3fc7fcfe1d69b77ffb9d9c9.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
2025-04-12 20:27:48 +00:00
71073caa00 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 19:16:33 +00:00
397d37acc5 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 03:11:38 +00:00
77407b38a9 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 575f348965abe8ea428eba7098f67ec9764a7f9a.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to Linter fails again, landrace this time? ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798392241))
2025-04-12 02:22:22 +00:00
575f348965 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 00:46:01 +00:00
83f14c0b06 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fad1bb9504482d930a2540d22427d383.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
2025-04-12 00:21:14 +00:00
5edfb4c4fa [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-11 23:21:35 +00:00
7ac8186851 [MPSInductor] Speedup sum/prod reductions (#150566)
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions

Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512         | 202.1       | 131.8       |
| 1024x1024   |   780.6    | 176.9       |
| 2048x2048    |   1423.4       | 339.9      |
| 4096x4097    |    2982.2 | 1047.2      |

Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,  0,  0,  0,  0, 0,  0,  0,  0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
2025-04-05 02:47:27 +00:00
b48505a8a1 [MPS] Add support for hermite_polynomial_h. (#150279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150279
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-03-31 23:30:19 +00:00
965784eb9b [MPSInductor] Specify max_total_threads_per_threadgroup (#150247)
When generating reduction kernel, otherwise compiler can unroll loops too much that kernel could not be launched for the intended threadgroup size

Extend `c10:🤘:max` to accept different dtypes

Together this fixes `test_large_broadcast_reduction`

TODO:
  - Explore different threadgroup_sizes for best perf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246
2025-03-29 19:37:15 +00:00
52135db69a [BE] Fix signed/unsigned comparison warning (#150246)
One will see them only if compilation fails, but still
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150246
Approved by: https://github.com/cyyever, https://github.com/jansel
2025-03-29 15:12:42 +00:00
6aca002d82 [MPS] Add chebyshev_polynomial_[uvw] (#150060)
For both eager and inductor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150060
Approved by: https://github.com/dcci, https://github.com/jansel
2025-03-26 23:35:05 +00:00
3a8171efad [MPS] Preserve in/out dtypes in binary_op name (#150024)
To be consistient with unary op and avoid silent correctness problems if someone will try to invoke the op with unexpected out dtype
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150024
Approved by: https://github.com/dcci
2025-03-26 16:00:43 +00:00
de68ddc68e [MPS] Fix metal ops with different dtypes (#149974)
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that tests `fmax`/`fmin` for mixed dtypes.

Been dreaded to write this PR for a while, as it end up to be pretty bulky:
 - Adds 1C10_METAL_ALL_TYPES_FUNCTOR` and `c10:🤘:ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
 - Add `c10:🤘:cast_to` to `c10/metal/utils.h` which could be used to cast any scalar metal dtype to any other one, including complex values
 - Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)` that is used to dynamically cast types
 - Add `binary_strided_cast` and `binary_dense_cast` that are invoked for output dtype and cast both inputs to that output before performing the op

Benchmark collected on M2Pro that runs fmax for 1 mln element tensors (Times are in microseconds.)

|                                           |  dense-dense  |  transp-transp  |  dense-transp  |  transp-dense  |  dense-scalar  |  dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
|      fmax (torch.float16, torch.float16)  |     160.9     |      159.9      |     270.5      |     270.9      |     236.6      |     293.0
|      fmax (torch.float32, torch.float32)  |     176.9     |      171.0      |     273.7      |     293.5      |     242.6      |     294.2
|      fmax (torch.float32, torch.float16)  |     171.4     |      170.9      |     283.6      |     303.0      |     253.7      |     302.3
|      add (torch.float16, torch.float16)   |     218.0     |      223.6      |     221.0      |     222.0      |     214.9      |     218.3
|      add (torch.float32, torch.float32)   |     227.4     |      233.9      |     228.8      |     231.9      |     218.9      |     221.4
|      add (torch.float32, torch.float16)   |     226.1     |      227.5      |     227.5      |     226.9      |     177.0      |     190.8

TODOS:
 - Include input and output dtype in non-cast kernel name
 - Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dytpes testing via OpInfo

Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
2025-03-26 07:03:21 +00:00