140 Commits

aea57b3aa3 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc.

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 16:06:36 +00:00
4412026949 Revert "AOTI MPS Shim Implementation (#163865)"
This reverts commit 874efa2d72d83b00894097130f18062ce331a265.

Reverted https://github.com/pytorch/pytorch/pull/163865 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/163865#issuecomment-3385196387))
2025-10-09 10:26:01 +00:00
874efa2d72 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc.

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 09:28:10 +00:00
ff2f319e6e [MPS] Fix conv layout handling (#162776)
What started as a simple fix for `mps_convolution_backward_input` resulted in a pretty significant refactor and a set of fixes:
- Updated `mps_conv_use_channels_last` to return a channels-last output if either the input or the weights are channels last (see the sketch after the redesign list below)
- Use the same primitive throughout `Convolution.mm` to determine whether the output should be allocated in channels-last format or not

But doing only those two resulted in a crash in `test_memory_format_nn_Conv2d_mps_float32` when the weights are channels last and a bias is present:
```
% python -c "import torch;print(torch.nn.functional.conv2d(torch.rand(2, 4, 3, 4,device='mps'), torch.rand(5, 4, 3, 3,device='mps').to(memory_format=torch.channels_last), torch.rand(5,device='mps')))"
/AppleInternal/Library/BuildRoots/4~B5E4ugDCh2RsPWAjMEoPu8LC5w1yXEwd7XweDhg/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:3619: failed assertion `Error: MLIR pass manager failed'
zsh: abort      python -c
```

This required a more thorough redesign/cleanup, namely:
- Do not alter the layout based on the MacOS version, but rather do additional copies on MacOS-14 if the inputs/output or weight are in channels-last format (done by defining `std::optional<Tensor> output_c;` that contains a contiguous copy of the output tensor)
- Introduce `input_suggested_layout`, which is set to ChannelsLast if and only if the input is channels last and the code is running on MacOS-15+
- Delete the unused `memory_layout` and `group` arguments from `fill_depthwise_conv_desc`
- Fix the bias broadcasting logic for channels last
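
A quick sanity check of the first bullet's layout behavior, as a hedged sketch based on the repro above (assumes an MPS device; the exact output layout can still depend on the macOS version):
```python
import torch

x = torch.rand(2, 4, 8, 8, device="mps")                                        # contiguous input
w = torch.rand(5, 4, 3, 3, device="mps").to(memory_format=torch.channels_last)  # channels-last weights
b = torch.rand(5, device="mps")

out = torch.nn.functional.conv2d(x, w, b)
# With the updated mps_conv_use_channels_last, channels-last weights alone are expected
# to be enough for a channels-last output (previously this combination could crash).
print(out.is_contiguous(memory_format=torch.channels_last))
```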

As a result, in addition to adding one more regression test, this change removes `expectedFailures` from:
- `TestModule.test_memory_format` for `Conv2d`, `ConvTranspose2d`, `LazyConv1d`, `LazyConvTranspose1d`
- `test_require_stride_expanded_dynamic_shapes`
- `test_mutable_custom_op_fixed_layout2` for MacOS-14

Fixes https://github.com/pytorch/pytorch/issues/161905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162776
Approved by: https://github.com/Skylion007
2025-09-25 23:41:34 +00:00
60b4791d08 [MPS] Fix compile linalg inv (#163452)
Fixes #161969

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163452
Approved by: https://github.com/Skylion007
2025-09-22 10:36:52 +00:00
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those functions, as they are significantly faster than before (though much slower than eager, due to the arg part), which can be verified by running the following script
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much trickier due to the nan-handling logic that needs to be added there.

Also fixes `torch.max/min` compilation for half-precision types and adds regression tests for it.
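
A minimal sketch of that half-precision case (assumed shapes and dtype; requires an MPS device):
```python
import torch

# Compile a reduction that returns both values and indices, in float16.
f = torch.compile(lambda t: torch.max(t, dim=0))
x = torch.rand(64, 64, dtype=torch.float16, device="mps")

values, indices = f(x)  # previously failed to compile for reduced-precision dtypes
print(values.dtype, indices.dtype)
```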

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64 and the `c10:🤘:pair` template, which are used by `simd_argmax` to return both the value and the index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
1c8844d9e7 [MPS] Switch Cholesky decomp to column wise (#157014)
Everything should go through generalized kernels, and Metal kernels should work with the same sizes and strides as the CPU or CUDA backends to avoid problems with `torch.compile`, which relies on the meta kernels to tell what the output is going to look like.

To avoid returning tensors with a different layout depending on whether the upper parameter is true or false, templatize `factorDiagonalBlock`, `applyTRSM` and `applySYRK` to take upper/lower (actually row-wise vs column-wise) as a template argument and call the appropriate templates from the host.
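
A rough illustration of the layout contract this aims for (a hedged sketch; assumes an MPS device and that `torch.linalg.cholesky` routes through these kernels):
```python
import torch

x = torch.rand(8, 8, device="mps")
a = x @ x.mT + 8 * torch.eye(8, device="mps")  # symmetric positive definite input

lower = torch.linalg.cholesky(a)
upper = torch.linalg.cholesky(a, upper=True)
# torch.compile relies on the meta kernels' strides, so both results should have
# the same sizes/strides as the CPU or CUDA backends would report.
print(lower.stride(), upper.stride())
```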

TODOs:
 - Rename upper parameter to something more sensible and add comments
 - Use simd_groupsize instead of hardcoded 32 everywhere

Fixes https://github.com/pytorch/pytorch/issues/156658

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157014
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157179
2025-07-01 18:00:59 +00:00
4ebd269065 [Testing] Remove duplicate MPSInductor tests (#157328)
They were added there before test_torchinductor was running in CI, but
now the same cases are covered by `GPUTests.test_pointwise_*_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157328
Approved by: https://github.com/huydhn
2025-07-01 00:21:22 +00:00
17dab018e3 [aoti][mps] Fix deduplication of kernels (#156843)
Previously I was not correctly deduplicating kernels generated for MPS, so it would generate multiple copies of the same kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156843
Approved by: https://github.com/desertfire
2025-06-26 21:03:05 +00:00
2e65d72e1e [aoti][mps] Fix int/symint kernel args (#155583)
Integer arguments to MPS kernels need to go through a different function, since `aoti_torch_mps_set_arg` only takes a Tensor. So I added an `aoti_torch_mps_set_arg_int`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155583
Approved by: https://github.com/desertfire
ghstack dependencies: #155752, #154287, #155582
2025-06-12 23:33:28 +00:00
ffbda61fbe [aoti][mps] Fix dynamic dispatch size (#155582)
In the case where we pass in a symint to the `dispatch` call, the compiler errors, so we need to cast the input to int64_t.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155582
Approved by: https://github.com/malfet
ghstack dependencies: #155752, #154287
2025-06-12 23:33:15 +00:00
a4ab392251 [aoti][mps] mps constants support (#154287)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154287
Approved by: https://github.com/malfet
ghstack dependencies: #155752
2025-06-12 23:33:07 +00:00
8821a9dc4e [BE][aoti][mps] Fix tests to use common function (#155752)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155752
Approved by: https://github.com/desertfire, https://github.com/malfet
2025-06-12 23:32:59 +00:00
da50835bde [aoti] Support c10 calls (#155256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155256
Approved by: https://github.com/malfet
2025-06-10 00:45:59 +00:00
26471fc203 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-23 05:45:35 +00:00
47a01f3efb Revert "[aoti] Initial Metal support (#153959)"
This reverts commit 28bcd9eb30336b370298dbe9677b95019882f2a8.

Reverted https://github.com/pytorch/pytorch/pull/153959 on behalf of https://github.com/angelayi due to previous PR broke frl build ([comment](https://github.com/pytorch/pytorch/pull/153959#issuecomment-2901825315))
2025-05-22 16:17:07 +00:00
28bcd9eb30 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-21 21:55:59 +00:00
694748dd9d [MPSInductor] Fix conv_transpose channels last (#153787)
Regardless of the input layout, transposed convolution always returns a contiguous tensor on MPS.
Add a test to validate that.
This fixes torch.compile for the SegmentAnything network.
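
A hedged sketch of the behavior the added test checks (assumed shapes; requires an MPS device):
```python
import torch

x = torch.rand(1, 4, 8, 8, device="mps").to(memory_format=torch.channels_last)
w = torch.rand(4, 8, 3, 3, device="mps")

f = torch.compile(torch.nn.functional.conv_transpose2d)
out = f(x, w)
# Even with channels-last input, the transposed convolution comes back contiguous on MPS,
# so the compiled graph must not assume a channels-last output.
print(out.is_contiguous())
```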

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153787
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #153786
2025-05-19 02:01:48 +00:00
db26aeaec2 [MPSInductor] Support numpy scalars handling (#153598)
By default, numpy computes results in float64 format, but when passed as an argument to an MPS function they must be implicitly converted to float32; this naturally occurs in some networks, for example in speech_transformer
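
A minimal sketch of the scenario (the scalar here is a hypothetical stand-in; requires an MPS device):
```python
import numpy as np
import torch

x = torch.rand(16, device="mps")
s = np.float64(0.5)  # numpy scalars default to float64, which MPS does not support

f = torch.compile(lambda t: t * s)
print(f(x))  # the scalar has to be implicitly narrowed to float32 for the MPS kernel
```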

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153598
Approved by: https://github.com/cyyever, https://github.com/dcci
ghstack dependencies: #153582
2025-05-15 16:48:25 +00:00
470132c6a1 [MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
2025-04-20 17:44:40 +00:00
fa6e842527 [MPS] Make fused rms_norm traceable (#150661)
This fixes a regression introduced by https://github.com/pytorch/pytorch/issues/150629#issue-2970312779, which I should have reviewed more thoroughly.

- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatch to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computations over a Tensor, only dispatch to other ops
- Register `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added unit test to avoid those regressions in the future

TODO:
- Get rid of this op, change `rms_norm_symint` definition to `CompositeExplicitAutograd` and implement backward function in `tools/autograd/derivatives.yaml`
- Benchmark compiler and re-enable decomp as follows when compiled code is faster
```python
@register_decomposition(aten._rms_norm_fused)
def rms_norm_fused(
    self: torch.Tensor, ndim: int, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    dtr = [self.dim() - i - 1 for i in range(ndim)]
    return self * weight * (self.pow(2).mean(dtr, keepdim=True).add(eps).rsqrt())
```

Fixes https://github.com/pytorch/pytorch/issues/150629

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150661
Approved by: https://github.com/manuelcandales, https://github.com/jansel
2025-04-17 11:32:00 +00:00
e4fe67f623 Revert "[MPS] Make fused rms_norm traceable (#150661)"
This reverts commit 682f09ec51526aefe6b504ac8081944baa866556.

Reverted https://github.com/pytorch/pytorch/pull/150661 on behalf of https://github.com/malfet due to Has decomp started to fail again ([comment](https://github.com/pytorch/pytorch/pull/150661#issuecomment-2812520408))
2025-04-17 11:06:05 +00:00
682f09ec51 [MPS] Make fused rms_norm traceable (#150661)
This fixes a regression introduced by https://github.com/pytorch/pytorch/issues/150629#issue-2970312779, which I should have reviewed more thoroughly.

- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatch to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computations over a Tensor, only dispatch to other ops
- Register `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added unit test to avoid those regressions in the future

TODO:
- Get rid of this op, change `rms_norm_symint` definition to `CompositeExplicitAutograd` and implement backward function in `tools/autograd/derivatives.yaml`
- Benchmark compiler and re-enable decomp as follows when compiled code is faster
```python
@register_decomposition(aten._rms_norm_fused)
def rms_norm_fused(
    self: torch.Tensor, ndim: int, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    dtr = [self.dim() - i - 1 for i in range(ndim)]
    return self * weight * (self.pow(2).mean(dtr, keepdim=True).add(eps).rsqrt())
```

Fixes https://github.com/pytorch/pytorch/issues/150629

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150661
Approved by: https://github.com/manuelcandales, https://github.com/jansel
2025-04-17 04:15:24 +00:00
d7050ef48b [CI] Run test_torchinductor for MPS device (#150821)
There are only 118 failures at the moment; mark them all with xfail to avoid new regressions.

Add an `xfail_if_mps_unimplemented` decorator to distinguish between tests that call an unimplemented eager op and ones that fail for some other reason.

Added an `aten._scaled_dot_product_attention_math_for_mps` fallback to make test behavior consistent between MacOS-15 (where the fallback is in place) and MacOS-14

Weird MacOS-14 specific skips:
- test_torchinductor.py::GPUTests::test_cat_extern_kernel_mps
- test_torchinductor.py::GPUTests::test_sort_transpose_mps (likely an eager bug)
- test_torchinductor.py::GPUTests::test_unaligned_input_mps

Numerous MacOS-13 skips, including a few eager hard crashes; for example, running `test_torchinductor.py::GPUTests::test_scatter5_mps` causes
```
/AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayScatter.mm:309: failed assertion `Rank of destination array (1) must be greater than or equal to inner-most dimension of indices array (3)'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150821
Approved by: https://github.com/ZainRizvi, https://github.com/dcci
ghstack dependencies: #151224, #151246, #151272, #151282, #151288
2025-04-15 18:42:39 +00:00
afaadce083 [MPSInductor] Adjust memory format detection (#151288)
The MPS conv implementation will only yield a channels-last output if the input is in channels_last format.
Fixes `TestGPUTests.test_conv2d_backward_channels_last` on MacOS-15
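
A hedged sketch of that dependence on the input layout (assumed shapes; requires an MPS device):
```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, device="mps")
x = torch.rand(1, 3, 16, 16, device="mps")

# Contiguous input -> contiguous output.
print(conv(x).is_contiguous())

# channels_last input -> channels_last output, which is what the inductor
# memory format detection now has to mirror.
x_cl = x.to(memory_format=torch.channels_last)
print(conv(x_cl).is_contiguous(memory_format=torch.channels_last))
```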

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151288
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #151224, #151246, #151272, #151282
2025-04-15 06:25:00 +00:00
070357b61a [MPSInductor] Fix silent correctness in bitcast (#151272)
By using Metal `as_type`, which according to the documentation does exactly that:
> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not
> a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in
> the operand are returned directly without modification as the new type. The usual type
> promotion for function arguments is not performed.

Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other.
Add an explicit cast to src_type to avoid errors due to type promotion (i.e. something like (x+1).view(dtype=torch.float16) would work correctly in eager mode for the int16 dtype, but would fail in compile, as arithmetic operations will promote int16 to int32).
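
The int16 example from above as a runnable sketch (requires an MPS device):
```python
import torch

x = torch.arange(4, dtype=torch.int16, device="mps")
f = torch.compile(lambda t: (t + 1).view(dtype=torch.float16))

# Without the explicit cast back to the source type, the compiled kernel would
# promote the int16 arithmetic to int32 before the bitcast and silently diverge.
print(f(x))
print((x + 1).view(dtype=torch.float16))  # eager reference
```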

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
2025-04-14 23:39:42 +00:00
46ce8f7df6 [MPSInductor] Cast halfs to floats (#151246)
To avoid accuracy issues when small reductions are unrolled, cast half to float during the `load` op, as `op_math_t<half>` is indeed float.
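
A small sketch of the kind of case this protects (assumed shape; requires an MPS device):
```python
import torch

x = torch.rand(3, dtype=torch.float16, device="mps")  # small reduction, gets unrolled
f = torch.compile(lambda t: t.sum())

# With loads upcast to float, the compiled result should agree with eager
# within half-precision tolerance.
print(f(x), x.sum())
```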

This fixes `test_unroll_small_reduction` for reduced precision types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151246
Approved by: https://github.com/dcci
ghstack dependencies: #151224
2025-04-14 19:47:04 +00:00
184ac8c7f7 [MPSInductor] Fix noop codegen (#151224)
By adding `pass` in front of the comment for the fake set_device call.
This fixes `TestGPU.test_zero_element_mutation_mps`, which previously
failed with
```
torch._inductor.exc.InductorError: RuntimeError: Failed to import /var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmp2emka_sx/7k/c7kmnwhb363ysalhewglr3cwtej6tiz3t4ppqa4bvhubaokmlprw.py
IndentationError: expected an indented block after 'with' statement on line 38 (c7kmnwhb363ysalhewglr3cwtej6tiz3t4ppqa4bvhubaokmlprw.py, line 40)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151224
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
2025-04-14 16:38:47 +00:00
d289d1177c [CI] Fix GPUTests.test_scheduler_vertical_fusion1 (#151166)
By enabling test_operators on the MPS device

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151166
Approved by: https://github.com/dcci
2025-04-13 00:41:51 +00:00
9699cc3eb9 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using the `welford_combine` primitive in the loop.
This fixes `GPUTests.test_multilayer_var_lowp_mps`
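
A rough repro of the shape of the fixed case (assumed sizes and dtype; requires an MPS device):
```python
import torch

# A low-precision variance whose reduction spans more than a single threadgroup,
# so partial Welford states have to be merged with welford_combine.
x = torch.rand(4, 1024, 1024, dtype=torch.float16, device="mps")
f = torch.compile(lambda t: torch.var_mean(t, dim=(-1, -2)))

var, mean = f(x)
print(var.shape, mean.shape)
```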

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 21:44:51 +00:00
7762bddd87 Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00836c23e3fc7fcfe1d69b77ffb9d9c9.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
2025-04-12 20:27:48 +00:00
71073caa00 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using the `welford_combine` primitive in the loop.
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 19:16:33 +00:00
397d37acc5 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 03:11:38 +00:00
77407b38a9 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 575f348965abe8ea428eba7098f67ec9764a7f9a.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to Linter fails again, landrace this time? ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798392241))
2025-04-12 02:22:22 +00:00
575f348965 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 00:46:01 +00:00
83f14c0b06 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fad1bb9504482d930a2540d22427d383.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
2025-04-12 00:21:14 +00:00
5edfb4c4fa [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-11 23:21:35 +00:00
b7c0fda163 [MPS] Fix determine_backend_memory_format logic (#151042)
If the input is channels last, then MPS will return a channels-last output

This fixed `GPUTests.test_convolution_4_mps` from test_torchinductor.py

That previously failed with
```
AssertionError: expected size 3==3, stride 1==192 at dim=1; expected size 12==12, stride 48==16 at dim=2; expected size 16==16, stride 3==1 at dim=3
```
As the FakeTensor implementation of conv returned a `Contiguous`, rather than `ChannelLast`, layout on MacOS-15 or later.
This doesn't seem to be very well documented, so I will try to document the call path for the `ExternKernel` invocation of `aten::convolution`:
- First, the inductor decomp defined here is called
 c93e4b8290/torch/_inductor/kernel/conv.py (L424-L425)

- Then it goes through the FakeTensor decomposition implemented here
320914f1b6/torch/_subclasses/fake_impls.py (L739-L740)
- Finally it goes down to convolution meta registrations implemented here
320914f1b6/torch/_meta_registrations.py (L2416-L2417)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151042
Approved by: https://github.com/dcci
2025-04-11 01:51:34 +00:00
c830c12a87 [MPSInductor] Fix tiled reduction logic (#150737)
In the case of tiles, the index must include both reduction dimensions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150737
Approved by: https://github.com/dcci
2025-04-06 00:20:41 +00:00
295b7e21eb [MPS/inductor] Add support for hermite_polynomial_h. (#150664)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150664
Approved by: https://github.com/malfet
2025-04-04 13:14:52 +00:00
82ceebce58 [inductor] Lowerings for max_pool3d (#148210)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148210
Approved by: https://github.com/eellison
2025-04-02 14:13:01 +00:00
f94ac263af [MPSInductor] Fix neg for unsigned types (#150412)
By more-or-less copy-n-pasting the fix from https://github.com/pytorch/pytorch/pull/94035

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150412
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150382, #150386
2025-04-01 16:52:41 +00:00
428234bc28 [MPSInductor] torch.complex128 is unsupported on MPS (#150386)
Same as torch.float64

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150386
Approved by: https://github.com/dcci
ghstack dependencies: #150382
2025-04-01 15:19:10 +00:00
965784eb9b [MPSInductor] Specify max_total_threads_per_threadgroup (#150247)
Specify it when generating a reduction kernel; otherwise the compiler can unroll loops so much that the kernel could not be launched for the intended threadgroup size

Extend `c10:🤘:max` to accept different dtypes

Together this fixes `test_large_broadcast_reduction`

TODO:
  - Explore different threadgroup_sizes for best perf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246
2025-03-29 19:37:15 +00:00
6aca002d82 [MPS] Add chebyshev_polynomial_[uvw] (#150060)
For both eager and inductor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150060
Approved by: https://github.com/dcci, https://github.com/jansel
2025-03-26 23:35:05 +00:00
db8f4c1b1b [MPSInductor] Run chebyshev_polynomial_t tests (#150042)
Test name should start with `test_`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150042
Approved by: https://github.com/dcci
2025-03-26 22:50:08 +00:00
e85ce64bde [MPS/Inductor] Add support for chebyshev_polynomial_t. (#149928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149928
Approved by: https://github.com/malfet
2025-03-25 21:02:13 +00:00
66b0a0b61a [inductor] support dilation in max_pool2d lowering (#148209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148209
Approved by: https://github.com/eellison
2025-03-24 13:00:12 +00:00
2b848ab192 [MPS/inductor] Add support for modified_scaled_bessel_k{0,1} (#149794)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149794
Approved by: https://github.com/malfet
2025-03-22 15:41:40 +00:00
0ed34210b2 [MPS] Add support for modified_bessel_k1 to eager and inductor. (#149687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149687
Approved by: https://github.com/malfet
2025-03-21 04:59:06 +00:00