# AOTI MPS Shim Implementation (#163865)

Manuel Candales · commit aea57b3aa3 · 2025-10-09

## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc. (plausible declarations are sketched below)
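
A plausible header sketch for these entry points, inferred from the call sites in the generated code shown below (the return type, parameter names, and exact signatures are assumptions, not the verbatim shim header):

```cpp
#include <cstdint>

// Status-code convention used across the AOTI C shim.
using AOTITorchError = int32_t;

// Opaque handle types; AOTIMetalShaderLibraryOpaque also appears in the
// generated code below. AtenTensorOpaque is the usual AOTI tensor handle.
struct AOTIMetalShaderLibraryOpaque;
struct AOTIMetalKernelFunctionOpaque;
struct AtenTensorOpaque;
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
using AtenTensorHandle = AtenTensorOpaque*;

extern "C" {
// Library management
AOTITorchError aoti_torch_mps_create_shader_library(
    const char* metal_shader_source,
    AOTIMetalShaderLibraryHandle* library);  // out-param
AOTITorchError aoti_torch_mps_delete_shader_library(
    AOTIMetalShaderLibraryHandle library);
AOTITorchError aoti_torch_mps_get_kernel_function(
    AOTIMetalShaderLibraryHandle library,
    const char* kernel_name,
    AOTIMetalKernelFunctionHandle* function);  // out-param

// Kernel execution
AOTITorchError aoti_torch_mps_start_encoding(
    AOTIMetalKernelFunctionHandle function);
AOTITorchError aoti_torch_mps_set_arg_tensor(
    AOTIMetalKernelFunctionHandle function,
    unsigned idx,
    AtenTensorHandle tensor);
AOTITorchError aoti_torch_mps_dispatch_single(
    AOTIMetalKernelFunctionHandle function,
    uint64_t length);
AOTITorchError aoti_torch_mps_run_command_block(
    AOTIMetalKernelFunctionHandle function,
    void (*callback)(AOTIMetalKernelFunctionHandle, void*),
    void* user_data);
}  // extern "C"
```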

## MPS Shader Codegen

*   Modified shader codegen to emit source string constants instead of instantiating `DynamicMetalShaderLibrary` directly:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation to emit calls to the shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = R"MTL(
    ...
)MTL";
```
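
Emitting a plain `const char*` means the generated code no longer names any libtorch C++ type at this point; constructing the library is deferred to runtime behind the C shim, consistent with AOTI's ABI-stability goal. On the libtorch side, the create call presumably wraps the same `DynamicMetalShaderLibrary`, roughly as in this sketch (an assumption; the shim implementation is not shown in this diff):

```cpp
// Plausible shim-side implementation inside libtorch (an assumption; the
// actual translation unit and header paths in the PR may differ).
#include <ATen/native/mps/OperationUtils.h>        // assumed home of DynamicMetalShaderLibrary
#include <torch/csrc/inductor/aoti_torch/utils.h>  // AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE

AOTITorchError aoti_torch_mps_create_shader_library(
    const char* metal_shader_source,
    AOTIMetalShaderLibraryHandle* library) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // The shim owns the library object and hands back an opaque handle;
    // aoti_torch_mps_delete_shader_library would undo this allocation.
    auto* lib = new at::native::mps::DynamicMetalShaderLibrary(metal_shader_source);
    *library = reinterpret_cast<AOTIMetalShaderLibraryHandle>(lib);
  });
}
```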

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```
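
A note on the design: the initializer returns the kernel handle paired with a `unique_ptr` that owns the library, so the function-local `static` keeps the library alive for the program's lifetime (the kernel handle would otherwise dangle), and the custom deleter releases it through the shim during static destruction.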

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```
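
`aoti_torch_mps_shared_callback` itself is not defined in the snippet above. Because the shim boundary is a plain C ABI, `aoti_torch_mps_run_command_block` cannot accept a `std::function` directly, which is why the generated code wraps the lambda and passes its address as user data. A minimal sketch of such a trampoline, assuming that shape (not the verbatim PR code):

```cpp
#include <functional>

// Hypothetical trampoline invoked by the shim from inside the command block;
// user_data carries the std::function wrapper across the C ABI.
extern "C" void aoti_torch_mps_shared_callback(
    AOTIMetalKernelFunctionHandle handle,
    void* user_data) {
  auto* fn = static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(user_data);
  (*fn)(handle);
}
```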

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire

## pytorch/test/inductor/test_mps_basic.py

```python
# Owner(s): ["module: mps"]
import importlib
import os
import sys

import numpy as np

import torch
from torch.testing import FileCheck, make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize,
)


MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble] + (
    [torch.bfloat16] if MACOS_VERSION < 14.0 else []
)
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]

importlib.import_module("filelock")

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    check_model_gpu,
    CommonTemplate,
    TestCase,
)


# TODO: Remove this file.
# This tests basic MPS compile functionality
@instantiate_parametrized_tests
class MPSBasicTests(TestCase):
    is_dtype_supported = CommonTemplate.is_dtype_supported
    common = check_model_gpu
    device = "mps"

    @parametrize("dtype", MPS_DTYPES)
    def test_add(self, dtype):
        self.common(
            lambda a, b: a + b,
            (
                make_tensor(1024, dtype=dtype, device=self.device),
                make_tensor(1024, dtype=dtype, device=self.device),
            ),
            check_lowp=False,
        )

    def test_log(self):
        self.common(lambda x: x.log(), (torch.rand(1024),))

    def test_acos(self):
        self.common(lambda x: x.acos(), (torch.rand(1024),))

    def test_atanh(self):
        self.common(lambda x: x.atanh(), (torch.rand(1024),))

    def test_floor(self):
        self.common(lambda x: x.floor(), (torch.rand(1024),))

    def test_sign(self):
        self.common(lambda x: x.sign(), (torch.rand(1024),))

    def test_sliced_input(self):
        self.common(
            lambda x: x[:, ::2].sin() + x[:, 1::2].cos(), (torch.rand(32, 1024),)
        )

    def test_where(self):
        def foo(x):
            rc = x.abs().sqrt()
            rc[x < 0] = -5
            return rc

        self.common(foo, (torch.rand(1024),))

    @parametrize("dtype", MPS_DTYPES)
    def test_cast(self, dtype):
        self.common(lambda a: a.to(dtype), (torch.rand(1024),))

    def test_broadcast(self):
        self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024)))

    def test_inplace(self):
        def inc_(x):
            x += 1
            return x

        self.common(inc_, (torch.rand(1024),))

    def test_rms_norm_nograd(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/150629
        def fn(x, w):
            with torch.no_grad():
                return torch.nn.functional.rms_norm(x, x.shape, w)

        self.common(fn, (torch.rand(10), torch.ones(10)))

    def test_compile_numpy_scalar(self):
        def fn(x, y):
            return x / y

        self.common(fn, (torch.rand(10), np.exp(0.3)))

    def test_conv_transpose_channels_last(self):
        def fn(x, y):
            return torch.nn.functional.conv_transpose2d(x, y, stride=1, padding=1)

        self.common(
            fn,
            (
                torch.rand(1, 1, 16, 16).to(memory_format=torch.channels_last),
                torch.rand(1, 4, 8, 8),
            ),
        )

    def test_conv_train(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/161905
        def fn(x, y):
            return torch.nn.functional.conv2d(x, y, None, 1, 1, 1)

        self.common(
            fn,
            (
                torch.rand(4, 512, 7, 7, requires_grad=True),
                torch.rand(512, 512, 3, 3),
            ),
            check_gradient=True,
        )

    def test_cholesky(self):
        def fn(x):
            return (
                torch.linalg.cholesky(x, upper=False),
                torch.linalg.cholesky(x, upper=True),
            )

        self.common(fn, (torch.eye(64),), check_lowp=False)

    def test_reduced_max(self):
        # Inductor tests do not validate that the max of, say, 16K half elements can be computed
        self.common(torch.max, (torch.rand(16384, dtype=torch.half),), check_lowp=False)

    def test_linalg_inv(self):
        def fn(x):
            return torch.linalg.inv(torch.linalg.cholesky(x))

        A = torch.diag(torch.tensor([20.0, 0.5, 5.0], dtype=torch.float32) ** 2)
        self.common(fn, (A,), check_lowp=False)


class MPSBasicTestsAOTI(TestCase):
    def check_model(self, m, inp, dynamic_shapes=None):
        res2 = m(*inp)
        ep = torch.export.export(m, inp, dynamic_shapes=dynamic_shapes)
        path = torch._inductor.aoti_compile_and_package(ep)
        m = torch._inductor.aoti_load_package(path)
        res = m(*inp)
        assert torch.allclose(res, res2)

    def test_add_mps(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, 3, device="mps"), torch.ones(3, 3, device="mps"))
        m = M().to("mps")
        self.check_model(m, inp)

    def test_fallback_mps(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.nn.functional.linear(x, y)

        inp = (
            torch.randn(10, 10, device="mps"),
            torch.randn(10, 10, device="mps"),
        )
        m = M().to("mps")
        self.check_model(m, inp)

    def test_c10(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

        inp = (torch.randn(2, 8, device="mps"),)
        m = M().to("mps")
        self.check_model(m, inp)

    def test_two_const(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.y = torch.ones(3, 3, device="mps")
                self.z = torch.full((3, 3), 2, device="mps")

            def forward(self, x):
                return x + self.y + self.z

        inp = (torch.ones(3, 3, device="mps"),)
        m = Model().to(device="mps")
        self.check_model(m, inp)

    def test_simple_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                add_0 = x + y
                return torch.nn.functional.relu(input=add_0, inplace=False)

        x = torch.randn(128, 2048, device="mps")
        y = torch.randn(128, 2048, device="mps")
        inp = (x, y)
        m = Model().to(device="mps")
        dim0_x = torch.export.Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
        self.check_model(m, inp, dynamic_shapes)

    def test_dynamic_cat(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        a = torch.randn(2, 4, device="mps")
        b = torch.randn(3, 4, device="mps")
        inp = (a, b)
        m = Model().to(device="mps")
        dim0_a = torch.export.Dim("dim0_a", min=1, max=10)
        dim0_b = torch.export.Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        self.check_model(m, inp, dynamic_shapes)

    def test_reuse_kernel(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.mm(a, y)
                c = torch.sin(b)
                d = torch.mm(b, c)
                return d

        example_inputs = (
            torch.randn(87, 87, device="mps"),
            torch.randn(87, 87, device="mps"),
        )
        model = Model()
        ep = torch.export.export(model, example_inputs)
        package_path = torch._export.aot_compile(ep.module(), example_inputs)

        target_str = "aoti_torch_mps_get_kernel_function("
        target_count = 1
        with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:
            src_code = cpp.read()
            FileCheck().check_count(
                target_str,
                target_count,
                exactly=True,
            ).run(src_code)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if torch.backends.mps.is_available():
        run_tests(needs="filelock")
```