AOTI MPS Shim Implementation (#163865)

## MPS Shim API

*   Updated the MPS shimification API with new handle types and function declarations (condensed excerpt below):
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, the `aoti_torch_mps_dispatch` variants, etc.
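
A condensed excerpt of the new declarations, taken from the shim header (`shim_mps.h`) changes in the diff further below; the full header also carries the pre-existing `aoti_torch_mps_set_arg_*` setters and adds array-dispatch variants:

```cpp
#include <torch/csrc/inductor/aoti_torch/c/shim.h>  // AOTI_TORCH_EXPORT, AOTITorchError

// Opaque handle types exposed by the shim
struct AOTIMetalShaderLibraryOpaque;
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;

// Library management
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
    const char* metal_shader_source,
    AOTIMetalShaderLibraryHandle* library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
    AOTIMetalShaderLibraryHandle library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
    AOTIMetalShaderLibraryHandle library_handle,
    const char* kernel_name,
    AOTIMetalKernelFunctionHandle* function_handle);

// Kernel execution: command blocks are driven through a plain C callback
typedef void (*aoti_torch_mps_command_block_callback_t)(
    AOTIMetalKernelFunctionHandle func,
    void* user_data);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
    AOTIMetalKernelFunctionHandle func,
    aoti_torch_mps_command_block_callback_t callback,
    void* user_data);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
    AOTIMetalKernelFunctionHandle func,
    uint64_t length);
```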

## MPS Shader Codegen

*   The shader codegen now emits a source-string constant instead of instantiating a `DynamicMetalShaderLibrary` directly:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel-call generation to emit calls to the shim API instead of direct libtorch calls (see the before/after comparison below)

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```
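
A note on lifetimes: the kernel pointer returned by `aoti_torch_mps_get_kernel_function` is owned by the library's internal `kernelCache` (see `getCachedKernelFunctionPtr` in the diff below), so the library must outlive every use of the kernel handle. The static pair accomplishes this by holding the library in a `unique_ptr` with the custom deleter alongside the cached handle; the library is only released (via `aoti_torch_mps_delete_shader_library`) at static destruction.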

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```
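
The `aoti_torch_mps_shared_callback` / `void*` pair is a trampoline: `aoti_torch_mps_run_command_block` is a C entry point that only accepts a plain function pointer plus a `user_data` pointer, so the generated code wraps its capturing lambda in a `std::function` and the shared callback recovers and invokes it. From the shim implementation in the diff below:

```cpp
// Shared trampoline: recover the std::function from user_data and invoke it
// with the kernel handle.
void aoti_torch_mps_shared_callback(
    AOTIMetalKernelFunctionHandle func,
    void* user_data) {
  auto* function_wrapper =
      static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(
          user_data);
  (*function_wrapper)(func);
}
```

`aoti_torch_mps_run_command_block` then calls `MetalKernelFunction::runCommandBlock` with a closure that invokes `callback(func, user_data)`, so the encoding, argument binding, and dispatch inside the lambda all run within the command block.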

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
Author: Manuel Candales
Date: 2025-10-09 16:06:36 +00:00
Committed by: PyTorch MergeBot
Parent: 3d1fa40ae1
Commit: aea57b3aa3
9 changed files with 372 additions and 66 deletions

View File

@@ -116,6 +116,8 @@ class MetalShaderLibrary {
std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
@@ -164,6 +166,9 @@ class MetalShaderLibrary {
std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
};
class DynamicMetalShaderLibrary : public MetalShaderLibrary {

View File

@@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
return std::make_shared<MetalKernelFunction>(cpl, func);
}
MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
// Check if kernel is already cached
auto it = kernelCache.find(name);
if (it != kernelCache.end()) {
return it->second.get();
}
// Create new kernel function and cache it
auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
MetalKernelFunction* raw_ptr = kernel.get();
kernelCache[name] = std::move(kernel);
return raw_ptr;
}
class BundledShaderLibary : public MetalShaderLibrary {
public:
BundledShaderLibary() : MetalShaderLibrary("") {}

View File

@@ -202,7 +202,7 @@ class AOTInductorTestsTemplate:
AOTIRunnerUtil.compile, model, example_inputs
)
if self.device == "mps":
FileCheck().check("getKernelFunction(").run(code)
FileCheck().check("aoti_torch_mps_get_kernel_function(").run(code)
elif self.device == GPU_TYPE:
FileCheck().check("launchKernel(").run(code)
if config.aot_inductor.embed_kernel_binary:
@@ -2893,7 +2893,7 @@ class AOTInductorTestsTemplate:
if self.device == "mps":
self.code_check_count(
model, example_inputs, '.getKernelFunction("generated_kernel")', 1
model, example_inputs, "aoti_torch_mps_get_kernel_function(", 1
)
elif self.device == GPU_TYPE:
self.code_check_count(

View File

@@ -270,7 +270,7 @@ class MPSBasicTestsAOTI(TestCase):
ep = torch.export.export(model, example_inputs)
package_path = torch._export.aot_compile(ep.module(), example_inputs)
target_str = 'mps_lib_0.getKernelFunction("generated_kernel")'
target_str = "aoti_torch_mps_get_kernel_function("
target_count = 1
with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:

View File

@@ -20,6 +20,7 @@ class CppWrapperMps(CppWrapperGpu):
def __init__(self) -> None:
super().__init__()
self._used_kernel_names: OrderedSet[str] = OrderedSet()
self._lambda_counter: int = 0
@staticmethod
def create(
@@ -47,13 +48,16 @@ class CppWrapperMps(CppWrapperGpu):
"""
Generates MPS kernel call code. It should look something like:
```
get_mps_lib_0()->runCommandBlock([&] {
get_mps_lib_0()->startEncoding();
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0);
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1);
...
get_mps_lib_0()->dispatch(9);
});
auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) {
aoti_torch_mps_start_encoding(handle);
aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
};
std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper = mps_lib_0_lambda;
aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper);
```
"""
device = device or V.graph.get_current_device_or_throw()
@@ -78,13 +82,9 @@ class CppWrapperMps(CppWrapperGpu):
new_args = []
for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
if isinstance(arg_type, torch.dtype):
new_args.append(
f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});"
)
new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});")
elif arg_type in (int, sympy.core.symbol.Symbol):
new_args.append(
f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});"
)
new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});")
else:
raise NotImplementedError(
f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
@@ -93,12 +93,85 @@ class CppWrapperMps(CppWrapperGpu):
threads, group_size = call_args[-2], call_args[-1]
if threads is None:
raise NotImplementedError("No threads or group_size provided")
elif group_size is None:
new_args.append(f"get_{kernel_name}()->dispatch({threads});\n")
# Check if threads is a single value or an array-like structure
threads_str = str(threads)
is_single_value = (
threads_str.startswith("{")
and threads_str.endswith("}")
and threads_str.count(",") == 0
) or not threads_str.startswith(("{", "["))
if is_single_value:
# Extract single value from braces if present
if threads_str.startswith("{") and threads_str.endswith("}"):
single_value = threads_str[1:-1].strip() # Remove braces
else:
single_value = threads_str
if group_size is None:
new_args.append(
f"aoti_torch_mps_dispatch_single(handle, {single_value});"
)
else:
# Extract group size value if it's also in braces
group_size_str = str(group_size)
if group_size_str.startswith("{") and group_size_str.endswith("}"):
group_size_value = group_size_str[1:-1].strip()
else:
group_size_value = group_size_str
new_args.append(
f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});"
)
else:
new_args.append(
f"get_{kernel_name}()->dispatch({threads}, {group_size});\n"
)
# Handle array case - need to convert initializer list to array
# Use kernel name to make variable names unique
threads_var = f"{kernel_name}_threads_array"
group_size_var = f"{kernel_name}_group_size_array"
# Extract array size from the initializer list string
def get_array_size(array_str: str) -> int:
# Remove braces and whitespace
content = array_str.strip()
if content.startswith("{") and content.endswith("}"):
content = content[1:-1].strip()
if not content: # Empty array
return 0
# Count elements by counting commas, accounting for nested structures
depth = 0
comma_count = 0
for char in content:
if char in "({[<":
depth += 1
elif char in ")}]>":
depth -= 1
elif char == "," and depth == 0:
comma_count += 1
return comma_count + 1 # Number of elements = commas + 1
threads_size = get_array_size(threads_str)
if group_size is None:
new_args.append("{")
new_args.append(f" uint64_t {threads_var}[] = {threads};")
new_args.append(
f" aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});"
)
new_args.append("}")
else:
group_size_str = str(group_size)
group_size_size = get_array_size(group_size_str)
new_args.append("{")
new_args.append(f" uint64_t {threads_var}[] = {threads};")
new_args.append(f" uint64_t {group_size_var}[] = {group_size};")
dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}"
new_args.append(
f" aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});"
)
new_args.append("}")
# debug printer related logic for cpp kernel type.
debug_printer_manager = V.graph.wrapper_code.debug_printer
@@ -113,14 +186,34 @@ class CppWrapperMps(CppWrapperGpu):
self.write_mps_kernel_call(kernel_name, new_args)
def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
# Initialization of the kernel function and kernel function handle
# variables have already been done at the beginning, which was
# codegen-ed in `codegen_mps_func_init`
self.writeline(f"get_{name}()->runCommandBlock([&] {{")
self.writeline(f" get_{name}()->startEncoding();")
# Generate unique variable names to avoid duplicate declarations
# when the same MPS lib is used multiple times
unique_suffix = self._lambda_counter
self._lambda_counter += 1
lambda_name = f"{name}_lambda_{unique_suffix}"
wrapper_name = f"{name}_func_wrapper_{unique_suffix}"
# Generate the function call code (in current location)
# Create lambda that captures by reference and pass its pointer through void*
self.writeline(
f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{"
)
self.writeline(" aoti_torch_mps_start_encoding(handle);")
# Output call args directly since we're capturing by reference
for call_arg in call_args:
self.writeline(f" {call_arg}")
self.writeline("});")
self.writeline("};")
self.writeline("")
# Pass lambda pointer through void*
self.writeline(
f"std::function<void(AOTIMetalKernelFunctionHandle)> {wrapper_name} = {lambda_name};"
)
self.writeline(
f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});"
)
@staticmethod
def get_device_include_path(device: str) -> str:
@@ -132,49 +225,77 @@ class CppWrapperMps(CppWrapperGpu):
def codegen_additional_funcs(self) -> None:
"""
We want to codegen the mps kernel function variable initializations
ahead of time. This is so that if we reuse kernels within subgraphs, we
don't need to worry about the scope in which we're initializing the
variables. Instead we will just initialize the variables all at the top
level.
Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup.
The kernel function variable initializations should look something like:
The generated code will look like:
```
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
return func;
}
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
return handle;
static auto kernel_handle = []() {
AOTIMetalShaderLibraryHandle lib_handle = nullptr;
AOTIMetalKernelFunctionHandle kern_handle = nullptr;
aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
// RAII wrapper with custom deleter
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
if (h) aoti_torch_mps_delete_shader_library(h);
};
using LibDeleter = decltype(lib_deleter);
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
// Return pair of kernel handle and library smart pointer for cleanup
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
}();
return kernel_handle.first;
}
```
"""
# Add shimified handles and functions
shader_libraries: OrderedSet[str] = OrderedSet()
for line in self.lines:
if not isinstance(line, KernelCallLine):
continue
if line.device.type != "mps":
continue
# Only add handle definition once
# Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls)
if line.kernel_name not in self._used_kernel_names:
self._used_kernel_names.add(line.kernel_name)
shader_libraries.add(line.kernel_name)
self.prefix.writeline(
f"const std::shared_ptr<at::native::mps::MetalKernelFunction> get_{line.kernel_name}() {{"
)
self.prefix.writeline(
f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");'
)
self.prefix.writeline(" return func;")
self.prefix.writeline("}")
# NOTE: For shimified version, we expect the shader source constant to be generated
# by the existing MPS shader generation process, but instead of instantiating the
# DynamicMetalShaderLibrary directly, we'll use our shim functions.
# The existing codegen should produce something like:
# const char* mps_lib_0_source = R"MTL(...shader_source...)MTL";
# instead of:
# at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL");
self.prefix.writeline(
f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{"
)
self.prefix.writeline(
f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());"
)
self.prefix.writeline(" return handle;")
self.prefix.writeline("}")
# Generate thread-safe lazy singleton with RAII for each library
for lib_name in shader_libraries:
self.prefix.splice(f"""
AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{
static auto kernel_handle = []() {{
AOTIMetalShaderLibraryHandle lib_handle = nullptr;
AOTIMetalKernelFunctionHandle kern_handle = nullptr;
aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle);
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
// RAII wrapper with custom deleter
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
if (h) aoti_torch_mps_delete_shader_library(h);
}};
using LibDeleter = decltype(lib_deleter);
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
// Return pair of kernel handle and library smart pointer for cleanup
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
}}();
return kernel_handle.first;
}}
""")

View File

@@ -1058,10 +1058,8 @@ class MetalScheduling(SIMDScheduling):
wrapper.src_to_kernel[src_code] = kernel_name
if V.graph.cpp_wrapper:
src_code = (
f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
+ src_code
)
# For shimified version, generate source constant instead of direct instantiation
src_code = f"const char* {mps_lib_name}_source = " + src_code
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
metadata_comment = f"{origins}\n{detailed_origins}"

View File

@@ -3,12 +3,32 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
struct AOTIMetalShaderLibraryOpaque;
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
#ifdef __cplusplus
extern "C" {
#endif
struct AOTIMetalKernelFunctionOpaque;
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
// MetalShaderLibrary functions
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
const char* metal_shader_source,
AOTIMetalShaderLibraryHandle* library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
AOTIMetalShaderLibraryHandle library_handle);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
AOTIMetalShaderLibraryHandle library_handle,
const char* kernel_name,
AOTIMetalKernelFunctionHandle* function_handle);
// MetalKernelFunction functions
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor(
AOTIMetalKernelFunctionHandle func,
@@ -20,6 +40,27 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int(
unsigned idx,
int64_t val);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
AOTIMetalKernelFunctionHandle func,
uint64_t length);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
AOTIMetalKernelFunctionHandle func,
uint64_t length,
uint64_t group_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size,
const uint64_t* group_size,
size_t group_size_size);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
@@ -39,6 +80,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer(
size_t src_offset,
size_t dst_offset);
// C callback function type for command block execution
typedef void (*aoti_torch_mps_command_block_callback_t)(
AOTIMetalKernelFunctionHandle func,
void* user_data);
// Shared callback function for std::function trampoline
AOTI_TORCH_EXPORT void aoti_torch_mps_shared_callback(
AOTIMetalKernelFunctionHandle func,
void* user_data);
// Pure C version using function pointer and user data for trampoline pattern
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
AOTIMetalKernelFunctionHandle func,
aoti_torch_mps_command_block_callback_t callback,
void* user_data);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@@ -27,3 +27,116 @@ AOTITorchError aoti_torch_mps_set_arg_int(
func->setArg(idx, val);
});
}
AOTITorchError aoti_torch_mps_create_shader_library(
const char* metal_shader_source,
AOTIMetalShaderLibraryHandle* library_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library = new at::native::mps::DynamicMetalShaderLibrary(
std::string(metal_shader_source));
*library_handle = reinterpret_cast<AOTIMetalShaderLibraryHandle>(library);
});
}
AOTITorchError aoti_torch_mps_delete_shader_library(
AOTIMetalShaderLibraryHandle library_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library =
reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
delete library;
});
}
AOTITorchError aoti_torch_mps_get_kernel_function(
AOTIMetalShaderLibraryHandle library_handle,
const char* kernel_name,
AOTIMetalKernelFunctionHandle* function_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* library =
reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
auto* function =
library->getCachedKernelFunctionPtr(std::string(kernel_name));
*function_handle =
reinterpret_cast<AOTIMetalKernelFunctionHandle>(function);
});
}
AOTITorchError aoti_torch_mps_start_encoding(
AOTIMetalKernelFunctionHandle func) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->startEncoding();
});
}
AOTITorchError aoti_torch_mps_dispatch_single(
AOTIMetalKernelFunctionHandle func,
uint64_t length) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->dispatch(length);
});
}
AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
AOTIMetalKernelFunctionHandle func,
uint64_t length,
uint64_t group_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->dispatch(length, group_size);
});
}
AOTITorchError aoti_torch_mps_dispatch_array(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
c10::ArrayRef<uint64_t> length_ref(length, length_size);
function_ptr->dispatch(length_ref);
});
}
AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
AOTIMetalKernelFunctionHandle func,
const uint64_t* length,
size_t length_size,
const uint64_t* group_size,
size_t group_size_size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
c10::ArrayRef<uint64_t> length_ref(length, length_size);
c10::ArrayRef<uint64_t> group_size_ref(group_size, group_size_size);
function_ptr->dispatch(length_ref, group_size_ref);
});
}
// Shared callback function for std::function trampoline
void aoti_torch_mps_shared_callback(
AOTIMetalKernelFunctionHandle func,
void* user_data) {
auto* function_wrapper =
static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(
user_data);
(*function_wrapper)(func);
}
// Pure C version using function pointer and user data for trampoline pattern
AOTITorchError aoti_torch_mps_run_command_block(
AOTIMetalKernelFunctionHandle func,
aoti_torch_mps_command_block_callback_t callback,
void* user_data) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* function_ptr =
reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
function_ptr->runCommandBlock(
[callback, func, user_data]() { callback(func, user_data); });
});
}

View File

@@ -1,4 +1,3 @@
#include <ATen/native/mps/MetalShaderLibrary.h>
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <ATen/mps/MPSAllocatorInterface.h>
@@ -6,7 +5,6 @@
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSProfiler.h>
using namespace torch::aot_inductor;
AOTITorchError aoti_torch_mps_malloc(
@@ -33,7 +31,6 @@ AOTITorchError aoti_torch_mps_free(
});
}
AOTITorchError
aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
@@ -46,7 +43,6 @@ aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, s
AOTITorchError
aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto src_mtl_buffer = (id<MTLBuffer>)src_buffer;
auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer;