mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
AOTI MPS Shim Implementation (#163865)
## MPS Shim API * Updated MPS shimification API with handles and function declarations: * `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types * Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function` * Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc ## MPS Shader Codegen * Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation: * **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");` * **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";` * Updated kernel call generation to use shimified functions: * Generates calls to shimified API instead of direct libtorch calls ## Before vs After Comparison ### Section 1: Shader Library **Before (Direct Library Object)** ```cpp at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL( ... )MTL"); ``` **After (Source String)** ```cpp const char* mps_lib_0_source = (R"MTL( ... 
)MTL"); ``` ### Section 2: Getter Functions & RAII Management **Before (Direct Library Access)** ```cpp const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() { static const auto func = mps_lib_0.getKernelFunction("generated_kernel"); return func; } AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get()); return handle; } ``` **After (Shim API + RAII Wrapper)** ```cpp AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { static auto kernel_handle = []() { AOTIMetalShaderLibraryHandle lib_handle = nullptr; AOTIMetalKernelFunctionHandle kern_handle = nullptr; aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle); aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle); // RAII wrapper with custom deleter auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{ if (h) aoti_torch_mps_delete_shader_library(h); }}; using LibDeleter = decltype(lib_deleter); using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>; // Return pair of kernel handle and library smart pointer for cleanup return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter)); }(); return kernel_handle.first; } ``` ### Section 3: Runtime Execution **Before (Direct Library Methods)** ```cpp void AOTInductorModel::run_impl(...) { ... get_mps_lib_0()->runCommandBlock([&] { get_mps_lib_0()->startEncoding(); aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0); aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1); aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1); get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)}); }); ... } // AOTInductorModel::run_impl ``` **After (Shim API with Lambda Pattern)** ```cpp void AOTInductorModel::run_impl(...) { ... 
auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) { aoti_torch_mps_start_encoding(handle); aoti_torch_mps_set_arg_tensor(handle, 0, buf0); aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1); aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1); aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL)); }; std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0; aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0); ... } // AOTInductorModel::run_impl ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865 Approved by: https://github.com/angelayi, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
3d1fa40ae1
commit
aea57b3aa3
@ -116,6 +116,8 @@ class MetalShaderLibrary {
|
||||
std::vector<std::string> getFunctionNames();
|
||||
std::shared_ptr<MetalKernelFunction> getKernelFunction(
|
||||
const std::string& name);
|
||||
// Returns a raw pointer to the kernel function for use in C APIs
|
||||
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
|
||||
inline MTLComputePipelineState_t getPipelineStateForFunc(
|
||||
const std::string& fname) {
|
||||
return getLibraryPipelineState(getLibrary(), fname).first;
|
||||
@ -164,6 +166,9 @@ class MetalShaderLibrary {
|
||||
std::string,
|
||||
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
|
||||
cplMap;
|
||||
// Cache for kernel functions returned by getCachedKernelFunctionPtr
|
||||
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
|
||||
kernelCache;
|
||||
};
|
||||
|
||||
class DynamicMetalShaderLibrary : public MetalShaderLibrary {
|
||||
|
@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
|
||||
return std::make_shared<MetalKernelFunction>(cpl, func);
|
||||
}
|
||||
|
||||
MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
  // Fast path: hand back the previously-built kernel for this name.
  if (auto cached = kernelCache.find(name); cached != kernelCache.end()) {
    return cached->second.get();
  }

  // Slow path: build the pipeline state, wrap it in a MetalKernelFunction,
  // and store it in the cache. The cache owns the object; callers get a raw
  // (non-owning) pointer suitable for use through the C shim API.
  auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
  auto& slot = kernelCache[name];
  slot = std::make_unique<MetalKernelFunction>(cpl, func);
  return slot.get();
}
|
||||
|
||||
class BundledShaderLibary : public MetalShaderLibrary {
|
||||
public:
|
||||
BundledShaderLibary() : MetalShaderLibrary("") {}
|
||||
|
@ -202,7 +202,7 @@ class AOTInductorTestsTemplate:
|
||||
AOTIRunnerUtil.compile, model, example_inputs
|
||||
)
|
||||
if self.device == "mps":
|
||||
FileCheck().check("getKernelFunction(").run(code)
|
||||
FileCheck().check("aoti_torch_mps_get_kernel_function(").run(code)
|
||||
elif self.device == GPU_TYPE:
|
||||
FileCheck().check("launchKernel(").run(code)
|
||||
if config.aot_inductor.embed_kernel_binary:
|
||||
@ -2893,7 +2893,7 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
if self.device == "mps":
|
||||
self.code_check_count(
|
||||
model, example_inputs, '.getKernelFunction("generated_kernel")', 1
|
||||
model, example_inputs, "aoti_torch_mps_get_kernel_function(", 1
|
||||
)
|
||||
elif self.device == GPU_TYPE:
|
||||
self.code_check_count(
|
||||
|
@ -270,7 +270,7 @@ class MPSBasicTestsAOTI(TestCase):
|
||||
ep = torch.export.export(model, example_inputs)
|
||||
package_path = torch._export.aot_compile(ep.module(), example_inputs)
|
||||
|
||||
target_str = 'mps_lib_0.getKernelFunction("generated_kernel")'
|
||||
target_str = "aoti_torch_mps_get_kernel_function("
|
||||
target_count = 1
|
||||
|
||||
with open(os.path.splitext(package_path)[0] + ".cpp") as cpp:
|
||||
|
@ -20,6 +20,7 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._used_kernel_names: OrderedSet[str] = OrderedSet()
|
||||
self._lambda_counter: int = 0
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@ -47,13 +48,16 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
"""
|
||||
Generates MPS kernel call code. It should look something like:
|
||||
```
|
||||
get_mps_lib_0()->runCommandBlock([&] {
|
||||
get_mps_lib_0()->startEncoding();
|
||||
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0);
|
||||
aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1);
|
||||
...
|
||||
get_mps_lib_0()->dispatch(9);
|
||||
});
|
||||
auto mps_lib_0_lambda = [&](AOTIMetalKernelFunctionHandle handle) {
|
||||
aoti_torch_mps_start_encoding(handle);
|
||||
aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
|
||||
aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
|
||||
aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
|
||||
aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
|
||||
};
|
||||
|
||||
std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper = mps_lib_0_lambda;
|
||||
aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper);
|
||||
```
|
||||
"""
|
||||
device = device or V.graph.get_current_device_or_throw()
|
||||
@ -78,13 +82,9 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
new_args = []
|
||||
for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])):
|
||||
if isinstance(arg_type, torch.dtype):
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});"
|
||||
)
|
||||
new_args.append(f"aoti_torch_mps_set_arg_tensor(handle, {idx}, {arg});")
|
||||
elif arg_type in (int, sympy.core.symbol.Symbol):
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});"
|
||||
)
|
||||
new_args.append(f"aoti_torch_mps_set_arg_int(handle, {idx}, {arg});")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}"
|
||||
@ -93,12 +93,85 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
threads, group_size = call_args[-2], call_args[-1]
|
||||
if threads is None:
|
||||
raise NotImplementedError("No threads or group_size provided")
|
||||
elif group_size is None:
|
||||
new_args.append(f"get_{kernel_name}()->dispatch({threads});\n")
|
||||
|
||||
# Check if threads is a single value or an array-like structure
|
||||
threads_str = str(threads)
|
||||
is_single_value = (
|
||||
threads_str.startswith("{")
|
||||
and threads_str.endswith("}")
|
||||
and threads_str.count(",") == 0
|
||||
) or not threads_str.startswith(("{", "["))
|
||||
|
||||
if is_single_value:
|
||||
# Extract single value from braces if present
|
||||
if threads_str.startswith("{") and threads_str.endswith("}"):
|
||||
single_value = threads_str[1:-1].strip() # Remove braces
|
||||
else:
|
||||
single_value = threads_str
|
||||
|
||||
if group_size is None:
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_dispatch_single(handle, {single_value});"
|
||||
)
|
||||
else:
|
||||
# Extract group size value if it's also in braces
|
||||
group_size_str = str(group_size)
|
||||
if group_size_str.startswith("{") and group_size_str.endswith("}"):
|
||||
group_size_value = group_size_str[1:-1].strip()
|
||||
else:
|
||||
group_size_value = group_size_str
|
||||
new_args.append(
|
||||
f"aoti_torch_mps_dispatch_single_with_group_size(handle, {single_value}, {group_size_value});"
|
||||
)
|
||||
else:
|
||||
new_args.append(
|
||||
f"get_{kernel_name}()->dispatch({threads}, {group_size});\n"
|
||||
)
|
||||
# Handle array case - need to convert initializer list to array
|
||||
# Use kernel name to make variable names unique
|
||||
threads_var = f"{kernel_name}_threads_array"
|
||||
group_size_var = f"{kernel_name}_group_size_array"
|
||||
|
||||
# Extract array size from the initializer list string
|
||||
def get_array_size(array_str: str) -> int:
    """Return the number of top-level elements in a C++ initializer-list
    string such as ``"{a, b, c}"``.

    Commas nested inside ``()``, ``{}``, ``[]`` or ``<>`` (e.g. inside a
    ``static_cast<...>(...)``) do not count as element separators.
    """
    body = array_str.strip()
    # Drop one surrounding brace pair, if present.
    if body.startswith("{") and body.endswith("}"):
        body = body[1:-1].strip()

    if not body:
        # Empty initializer list -> zero elements.
        return 0

    # Count separators only at nesting depth zero.
    nesting = 0
    separators = 0
    for ch in body:
        if ch in "({[<":
            nesting += 1
        elif ch in ")}]>":
            nesting -= 1
        elif ch == "," and nesting == 0:
            separators += 1

    # Number of elements is one more than the separator count.
    return separators + 1
|
||||
|
||||
threads_size = get_array_size(threads_str)
|
||||
|
||||
if group_size is None:
|
||||
new_args.append("{")
|
||||
new_args.append(f" uint64_t {threads_var}[] = {threads};")
|
||||
new_args.append(
|
||||
f" aoti_torch_mps_dispatch_array(handle, {threads_var}, {threads_size});"
|
||||
)
|
||||
new_args.append("}")
|
||||
else:
|
||||
group_size_str = str(group_size)
|
||||
group_size_size = get_array_size(group_size_str)
|
||||
new_args.append("{")
|
||||
new_args.append(f" uint64_t {threads_var}[] = {threads};")
|
||||
new_args.append(f" uint64_t {group_size_var}[] = {group_size};")
|
||||
dispatch_args = f"handle, {threads_var}, {threads_size}, {group_size_var}, {group_size_size}"
|
||||
new_args.append(
|
||||
f" aoti_torch_mps_dispatch_array_with_group_size({dispatch_args});"
|
||||
)
|
||||
new_args.append("}")
|
||||
|
||||
# debug printer related logic for cpp kernel type.
|
||||
debug_printer_manager = V.graph.wrapper_code.debug_printer
|
||||
@ -113,14 +186,34 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
self.write_mps_kernel_call(kernel_name, new_args)
|
||||
|
||||
def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None:
|
||||
# Initialization of the kernel function and kernel function handle
|
||||
# variables have already been done at the beginning, which was
|
||||
# codegen-ed in `codegen_mps_func_init`
|
||||
self.writeline(f"get_{name}()->runCommandBlock([&] {{")
|
||||
self.writeline(f" get_{name}()->startEncoding();")
|
||||
# Generate unique variable names to avoid duplicate declarations
|
||||
# when the same MPS lib is used multiple times
|
||||
unique_suffix = self._lambda_counter
|
||||
self._lambda_counter += 1
|
||||
|
||||
lambda_name = f"{name}_lambda_{unique_suffix}"
|
||||
wrapper_name = f"{name}_func_wrapper_{unique_suffix}"
|
||||
|
||||
# Generate the function call code (in current location)
|
||||
# Create lambda that captures by reference and pass its pointer through void*
|
||||
self.writeline(
|
||||
f"auto {lambda_name} = [&](AOTIMetalKernelFunctionHandle handle) {{"
|
||||
)
|
||||
self.writeline(" aoti_torch_mps_start_encoding(handle);")
|
||||
|
||||
# Output call args directly since we're capturing by reference
|
||||
for call_arg in call_args:
|
||||
self.writeline(f" {call_arg}")
|
||||
self.writeline("});")
|
||||
self.writeline("};")
|
||||
self.writeline("")
|
||||
|
||||
# Pass lambda pointer through void*
|
||||
self.writeline(
|
||||
f"std::function<void(AOTIMetalKernelFunctionHandle)> {wrapper_name} = {lambda_name};"
|
||||
)
|
||||
self.writeline(
|
||||
f"aoti_torch_mps_run_command_block(get_{name}_handle(), aoti_torch_mps_shared_callback, &{wrapper_name});"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_device_include_path(device: str) -> str:
|
||||
@ -132,49 +225,77 @@ class CppWrapperMps(CppWrapperGpu):
|
||||
|
||||
def codegen_additional_funcs(self) -> None:
|
||||
"""
|
||||
We want to codegen the mps kernel function variable initializations
|
||||
ahead of time. This is so that if we reuse kernels within subgraphs, we
|
||||
don't need to worry about the scope in which we're initializing the
|
||||
variables. Instead we will just initialize the variables all at the top
|
||||
level.
|
||||
Generate thread-safe lazy singleton pattern for MPS shader libraries with RAII cleanup.
|
||||
|
||||
The kernel function variable initializations should look something like:
|
||||
The generated code will look like:
|
||||
```
|
||||
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
|
||||
static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
|
||||
return func;
|
||||
}
|
||||
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
|
||||
static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
|
||||
return handle;
|
||||
static auto kernel_handle = []() {
|
||||
AOTIMetalShaderLibraryHandle lib_handle = nullptr;
|
||||
AOTIMetalKernelFunctionHandle kern_handle = nullptr;
|
||||
|
||||
aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
|
||||
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
|
||||
|
||||
// RAII wrapper with custom deleter
|
||||
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
|
||||
if (h) aoti_torch_mps_delete_shader_library(h);
|
||||
};
|
||||
|
||||
using LibDeleter = decltype(lib_deleter);
|
||||
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
|
||||
|
||||
// Return pair of kernel handle and library smart pointer for cleanup
|
||||
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
|
||||
}();
|
||||
return kernel_handle.first;
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
# Add shimified handles and functions
|
||||
shader_libraries: OrderedSet[str] = OrderedSet()
|
||||
for line in self.lines:
|
||||
if not isinstance(line, KernelCallLine):
|
||||
continue
|
||||
if line.device.type != "mps":
|
||||
continue
|
||||
|
||||
# Only add handle definition once
|
||||
# Extract library name from kernel name (e.g., "mps_lib_0" from kernel calls)
|
||||
if line.kernel_name not in self._used_kernel_names:
|
||||
self._used_kernel_names.add(line.kernel_name)
|
||||
shader_libraries.add(line.kernel_name)
|
||||
|
||||
self.prefix.writeline(
|
||||
f"const std::shared_ptr<at::native::mps::MetalKernelFunction> get_{line.kernel_name}() {{"
|
||||
)
|
||||
self.prefix.writeline(
|
||||
f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");'
|
||||
)
|
||||
self.prefix.writeline(" return func;")
|
||||
self.prefix.writeline("}")
|
||||
# NOTE: For shimified version, we expect the shader source constant to be generated
|
||||
# by the existing MPS shader generation process, but instead of instantiating the
|
||||
# DynamicMetalShaderLibrary directly, we'll use our shim functions.
|
||||
# The existing codegen should produce something like:
|
||||
# const char* mps_lib_0_source = R"MTL(...shader_source...)MTL";
|
||||
# instead of:
|
||||
# at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...shader_source...)MTL");
|
||||
|
||||
self.prefix.writeline(
|
||||
f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{"
|
||||
)
|
||||
self.prefix.writeline(
|
||||
f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());"
|
||||
)
|
||||
self.prefix.writeline(" return handle;")
|
||||
self.prefix.writeline("}")
|
||||
# Generate thread-safe lazy singleton with RAII for each library
|
||||
for lib_name in shader_libraries:
|
||||
self.prefix.splice(f"""
|
||||
AOTIMetalKernelFunctionHandle get_{lib_name}_handle() {{
|
||||
static auto kernel_handle = []() {{
|
||||
AOTIMetalShaderLibraryHandle lib_handle = nullptr;
|
||||
AOTIMetalKernelFunctionHandle kern_handle = nullptr;
|
||||
|
||||
aoti_torch_mps_create_shader_library({lib_name}_source, &lib_handle);
|
||||
aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);
|
||||
|
||||
// RAII wrapper with custom deleter
|
||||
auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
|
||||
if (h) aoti_torch_mps_delete_shader_library(h);
|
||||
}};
|
||||
|
||||
using LibDeleter = decltype(lib_deleter);
|
||||
using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;
|
||||
|
||||
// Return pair of kernel handle and library smart pointer for cleanup
|
||||
return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
|
||||
}}();
|
||||
return kernel_handle.first;
|
||||
}}
|
||||
""")
|
||||
|
@ -1058,10 +1058,8 @@ class MetalScheduling(SIMDScheduling):
|
||||
wrapper.src_to_kernel[src_code] = kernel_name
|
||||
|
||||
if V.graph.cpp_wrapper:
|
||||
src_code = (
|
||||
f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}"
|
||||
+ src_code
|
||||
)
|
||||
# For shimified version, generate source constant instead of direct instantiation
|
||||
src_code = f"const char* {mps_lib_name}_source = " + src_code
|
||||
|
||||
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
|
||||
metadata_comment = f"{origins}\n{detailed_origins}"
|
||||
|
@ -3,12 +3,32 @@
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
|
||||
struct AOTIMetalKernelFunctionOpaque;
|
||||
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
|
||||
|
||||
struct AOTIMetalShaderLibraryOpaque;
|
||||
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct AOTIMetalKernelFunctionOpaque;
|
||||
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
|
||||
// MetalShaderLibrary functions
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_create_shader_library(
|
||||
const char* metal_shader_source,
|
||||
AOTIMetalShaderLibraryHandle* library_handle);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_delete_shader_library(
|
||||
AOTIMetalShaderLibraryHandle library_handle);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_get_kernel_function(
|
||||
AOTIMetalShaderLibraryHandle library_handle,
|
||||
const char* kernel_name,
|
||||
AOTIMetalKernelFunctionHandle* function_handle);
|
||||
|
||||
// MetalKernelFunction functions
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
@ -20,6 +40,27 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int(
|
||||
unsigned idx,
|
||||
int64_t val);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
uint64_t length);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
uint64_t length,
|
||||
uint64_t group_size);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
const uint64_t* length,
|
||||
size_t length_size);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
const uint64_t* length,
|
||||
size_t length_size,
|
||||
const uint64_t* group_size,
|
||||
size_t group_size_size);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
|
||||
|
||||
@ -39,6 +80,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer(
|
||||
size_t src_offset,
|
||||
size_t dst_offset);
|
||||
|
||||
// C callback function type for command block execution
|
||||
typedef void (*aoti_torch_mps_command_block_callback_t)(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
void* user_data);
|
||||
|
||||
// Shared callback function for std::function trampoline
|
||||
AOTI_TORCH_EXPORT void aoti_torch_mps_shared_callback(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
void* user_data);
|
||||
|
||||
// Pure C version using function pointer and user data for trampoline pattern
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_run_command_block(
|
||||
AOTIMetalKernelFunctionHandle func,
|
||||
aoti_torch_mps_command_block_callback_t callback,
|
||||
void* user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
@ -27,3 +27,116 @@ AOTITorchError aoti_torch_mps_set_arg_int(
|
||||
func->setArg(idx, val);
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_create_shader_library(
    const char* metal_shader_source,
    AOTIMetalShaderLibraryHandle* library_handle) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Compile the Metal source into a dynamic shader library. Ownership of
    // the allocation is transferred to the caller, who must release it with
    // aoti_torch_mps_delete_shader_library.
    auto* lib = new at::native::mps::DynamicMetalShaderLibrary(
        std::string(metal_shader_source));
    *library_handle = reinterpret_cast<AOTIMetalShaderLibraryHandle>(lib);
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_delete_shader_library(
    AOTIMetalShaderLibraryHandle library_handle) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // NOTE(review): the handle was created as a DynamicMetalShaderLibrary but
    // is destroyed through a MetalShaderLibrary*. This is only well-defined if
    // the base class has a virtual destructor — confirm in MetalShaderLibrary.h.
    delete reinterpret_cast<at::native::mps::MetalShaderLibrary*>(
        library_handle);
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_get_kernel_function(
    AOTIMetalShaderLibraryHandle library_handle,
    const char* kernel_name,
    AOTIMetalKernelFunctionHandle* function_handle) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // The returned kernel pointer is owned by the library's internal cache
    // (getCachedKernelFunctionPtr), so the caller must keep the library alive
    // for as long as the kernel handle is in use.
    auto* lib =
        reinterpret_cast<at::native::mps::MetalShaderLibrary*>(library_handle);
    *function_handle = reinterpret_cast<AOTIMetalKernelFunctionHandle>(
        lib->getCachedKernelFunctionPtr(std::string(kernel_name)));
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_start_encoding(
    AOTIMetalKernelFunctionHandle func) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Begin encoding commands for this kernel on the underlying
    // MetalKernelFunction.
    reinterpret_cast<at::native::mps::MetalKernelFunction*>(func)
        ->startEncoding();
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_dispatch_single(
    AOTIMetalKernelFunctionHandle func,
    uint64_t length) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // 1-D dispatch with a single total-thread-count value.
    reinterpret_cast<at::native::mps::MetalKernelFunction*>(func)
        ->dispatch(length);
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
    AOTIMetalKernelFunctionHandle func,
    uint64_t length,
    uint64_t group_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // 1-D dispatch with an explicit threadgroup size.
    reinterpret_cast<at::native::mps::MetalKernelFunction*>(func)
        ->dispatch(length, group_size);
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_dispatch_array(
    AOTIMetalKernelFunctionHandle func,
    const uint64_t* length,
    size_t length_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Multi-dimensional dispatch; the (pointer, size) pair is viewed as a
    // non-owning c10::ArrayRef for the underlying dispatch overload.
    auto* kernel =
        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
    kernel->dispatch(c10::ArrayRef<uint64_t>(length, length_size));
  });
}
|
||||
|
||||
AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
    AOTIMetalKernelFunctionHandle func,
    const uint64_t* length,
    size_t length_size,
    const uint64_t* group_size,
    size_t group_size_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Multi-dimensional dispatch with an explicit threadgroup shape; both
    // (pointer, size) pairs are viewed as non-owning c10::ArrayRefs.
    auto* kernel =
        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
    kernel->dispatch(
        c10::ArrayRef<uint64_t>(length, length_size),
        c10::ArrayRef<uint64_t>(group_size, group_size_size));
  });
}
|
||||
|
||||
// Shared callback function for std::function trampoline
|
||||
void aoti_torch_mps_shared_callback(
    AOTIMetalKernelFunctionHandle func,
    void* user_data) {
  // Trampoline: user_data points at a std::function wrapper created by the
  // generated model code; recover it and forward the kernel handle.
  using KernelFn = std::function<void(AOTIMetalKernelFunctionHandle)>;
  (*static_cast<KernelFn*>(user_data))(func);
}
|
||||
|
||||
// Pure C version using function pointer and user data for trampoline pattern
|
||||
AOTITorchError aoti_torch_mps_run_command_block(
    AOTIMetalKernelFunctionHandle func,
    aoti_torch_mps_command_block_callback_t callback,
    void* user_data) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Pure C trampoline: run the caller-supplied function pointer (with its
    // opaque user_data) inside the kernel's command-block scope. The lambda
    // captures by value, so nothing here can dangle during the block.
    auto* kernel =
        reinterpret_cast<at::native::mps::MetalKernelFunction*>(func);
    kernel->runCommandBlock(
        [callback, func, user_data] { callback(func, user_data); });
  });
}
|
||||
|
@ -1,4 +1,3 @@
|
||||
#include <ATen/native/mps/MetalShaderLibrary.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim_mps.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
@ -6,7 +5,6 @@
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
|
||||
|
||||
using namespace torch::aot_inductor;
|
||||
|
||||
AOTITorchError aoti_torch_mps_malloc(
|
||||
@ -33,7 +31,6 @@ AOTITorchError aoti_torch_mps_free(
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
AOTITorchError
|
||||
aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
@ -46,7 +43,6 @@ aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, s
|
||||
AOTITorchError
|
||||
aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
|
||||
auto src_mtl_buffer = (id<MTLBuffer>)src_buffer;
|
||||
auto dst_mtl_buffer = (id<MTLBuffer>)dst_buffer;
|
||||
|
||||
|
Reference in New Issue
Block a user